{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Deep Learning Model for Asian Barrier Options\n",
    "\n",
     "As shown in the previous notebook, generating data on the fly has a few problems:\n",
     "\n",
     "    1. There is no model serialization, so the trained model is not saved\n",
     "    2. There is no validation dataset to check the training progress\n",
     "    3. Most of the time is spent on the Monte Carlo simulation, hence the training is slow\n",
     "    4. We use only a few paths (1024) for each option parameter set, so the prices are noisy and the model cannot converge to a low cost value\n",
     "\n",
     "The solution is to save the Monte Carlo simulation data to disk. This allows us to\n",
     "\n",
     "    1. Reuse the same dataset for different models and save the Monte Carlo simulation time\n",
     "    2. Generate more accurate pricing data by increasing the number of paths\n",
     "\n",
     "We will use CuPy to run the Monte Carlo simulation, as it is the most efficient approach. We take the same `OptionDataSet` defined in the previous notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from cupy_dataset import OptionDataSet"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Making the directories for the saved data files and the model checkpoints:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir -p datafiles\n",
    "!mkdir -p check_points"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Defining a function to generate the dataset files:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def gen_data(n_files = 630, options_per_file = 10000, seed=3):\n",
    "    counter = 0\n",
    "    ds = OptionDataSet(max_len=n_files * options_per_file, number_path=8192000, batch=1,\n",
    "                   seed=seed)\n",
    "    x = []\n",
    "    y = []\n",
    "    for i in ds:\n",
    "        if counter!=0 and counter % options_per_file == 0:\n",
    "            filename = 'datafiles/'+str(seed) + '_' + str(counter//options_per_file) + '.pth'\n",
    "            state = (torch.cat(x, 0), torch.cat(y, 0))\n",
    "            torch.save(state, filename)\n",
    "            x = []\n",
    "            y = []\n",
    "        x.append(i[0].cpu())\n",
    "        y.append(i[1].cpu())\n",
    "        counter += 1\n",
    "    return seed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "It will generate files containing the `X` and `Y` matrices, each holding `options_per_file` samples, and the filenames follow the format `seed_group.pth`. We can do a test run with `n_files=5` and `options_per_file=16`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[1.7910e+02, 6.8079e+01, 1.0688e+02, 2.5889e-01, 1.7393e-01, 1.4359e-01],\n",
      "        [1.3597e+02, 5.8014e+01, 1.0772e+02, 1.1119e-01, 1.1278e-01, 3.3107e-03],\n",
      "        [4.7951e+01, 3.6957e+01, 8.0480e+01, 2.6536e-01, 5.3653e-02, 7.2782e-02],\n",
      "        [1.0026e+02, 8.1533e+00, 6.6216e+01, 3.8491e-02, 5.5396e-02, 1.4566e-01],\n",
      "        [1.0416e+02, 7.9586e+01, 1.0620e+02, 1.2557e-01, 1.9639e-02, 3.0966e-02],\n",
      "        [1.6851e+02, 9.7813e+01, 1.2468e+02, 1.1845e-01, 7.9473e-02, 1.0369e-01],\n",
      "        [1.6673e+02, 7.4595e+01, 6.4872e+01, 3.8445e-01, 4.0116e-02, 1.5097e-01],\n",
      "        [3.2400e+01, 1.4736e+01, 9.4934e+01, 2.5872e-01, 6.7174e-02, 1.0737e-01],\n",
      "        [1.2953e+02, 8.5337e+01, 1.2570e+02, 1.6452e-01, 7.1083e-02, 1.9993e-01],\n",
      "        [1.5920e+02, 1.3722e+02, 6.4502e+01, 3.5891e-01, 1.5036e-01, 1.8909e-01],\n",
      "        [4.7439e+00, 6.8898e-01, 1.7892e+01, 1.6206e-02, 1.1772e-01, 1.1536e-01],\n",
      "        [1.4590e+02, 5.5645e+00, 9.4114e+00, 9.8751e-02, 7.2455e-03, 1.2266e-01],\n",
      "        [1.0537e+02, 4.6149e+01, 7.2182e+01, 2.0814e-01, 1.5636e-02, 4.7667e-02],\n",
      "        [1.9498e+02, 1.4687e+02, 5.9092e+01, 5.9770e-02, 4.7395e-02, 8.9560e-02],\n",
      "        [5.4070e+00, 4.4146e+00, 1.3971e+02, 3.4593e-01, 1.8324e-01, 1.3890e-01],\n",
      "        [6.1022e+01, 3.5528e+01, 3.8339e+01, 1.4686e-01, 1.2386e-01, 1.2188e-01]])\n",
      "tensor([2.1621e-02, 1.0037e-02, 3.2299e+01, 0.0000e+00, 4.7080e+00, 2.7595e-04,\n",
      "        0.0000e+00, 5.9109e+01, 4.3838e+00, 0.0000e+00, 1.2694e+01, 0.0000e+00,\n",
      "        4.3242e-03, 0.0000e+00, 1.2877e+02, 2.6165e-06])\n"
     ]
    }
   ],
   "source": [
    "gen_data(n_files=5, options_per_file = 16, seed=3)\n",
    "X, Y = torch.load('datafiles/3_1.pth')\n",
    "print(X)\n",
    "print(Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "We will use Dask to generate the dataset on multiple GPUs in this notebook:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<table style=\"border: 2px solid white;\">\n",
       "<tr>\n",
       "<td style=\"vertical-align: top; border: 0px solid white\">\n",
       "<h3 style=\"text-align: left;\">Client</h3>\n",
       "<ul style=\"text-align: left; list-style: none; margin: 0; padding: 0;\">\n",
       "  <li><b>Scheduler: </b>tcp://127.0.0.1:43453</li>\n",
       "  <li><b>Dashboard: </b><a href='http://127.0.0.1:8787/status' target='_blank'>http://127.0.0.1:8787/status</a>\n",
       "</ul>\n",
       "</td>\n",
       "<td style=\"vertical-align: top; border: 0px solid white\">\n",
       "<h3 style=\"text-align: left;\">Cluster</h3>\n",
       "<ul style=\"text-align: left; list-style:none; margin: 0; padding: 0;\">\n",
       "  <li><b>Workers: </b>8</li>\n",
       "  <li><b>Cores: </b>8</li>\n",
       "  <li><b>Memory: </b>540.94 GB</li>\n",
       "</ul>\n",
       "</td>\n",
       "</tr>\n",
       "</table>"
      ],
      "text/plain": [
       "<Client: 'tcp://127.0.0.1:43453' processes=8 threads=8, memory=540.94 GB>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import dask\n",
    "import dask_cudf\n",
    "from dask.delayed import delayed\n",
    "from dask_cuda import LocalCUDACluster\n",
    "cluster = LocalCUDACluster()\n",
    "from dask.distributed import Client\n",
    "client = Client(cluster)\n",
    "client"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "The following code is an example that generates `100x5x16` data points across the GPUs in the cluster. For serious deep learning model training, we need millions of data points; you can change `n_files` and `options_per_file` to larger numbers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0,\n",
       " 1,\n",
       " 2,\n",
       " 3,\n",
       " 4,\n",
       " 5,\n",
       " 6,\n",
       " 7,\n",
       " 8,\n",
       " 9,\n",
       " 10,\n",
       " 11,\n",
       " 12,\n",
       " 13,\n",
       " 14,\n",
       " 15,\n",
       " 16,\n",
       " 17,\n",
       " 18,\n",
       " 19,\n",
       " 20,\n",
       " 21,\n",
       " 22,\n",
       " 23,\n",
       " 24,\n",
       " 25,\n",
       " 26,\n",
       " 27,\n",
       " 28,\n",
       " 29,\n",
       " 30,\n",
       " 31,\n",
       " 32,\n",
       " 33,\n",
       " 34,\n",
       " 35,\n",
       " 36,\n",
       " 37,\n",
       " 38,\n",
       " 39,\n",
       " 40,\n",
       " 41,\n",
       " 42,\n",
       " 43,\n",
       " 44,\n",
       " 45,\n",
       " 46,\n",
       " 47,\n",
       " 48,\n",
       " 49,\n",
       " 50,\n",
       " 51,\n",
       " 52,\n",
       " 53,\n",
       " 54,\n",
       " 55,\n",
       " 56,\n",
       " 57,\n",
       " 58,\n",
       " 59,\n",
       " 60,\n",
       " 61,\n",
       " 62,\n",
       " 63,\n",
       " 64,\n",
       " 65,\n",
       " 66,\n",
       " 67,\n",
       " 68,\n",
       " 69,\n",
       " 70,\n",
       " 71,\n",
       " 72,\n",
       " 73,\n",
       " 74,\n",
       " 75,\n",
       " 76,\n",
       " 77,\n",
       " 78,\n",
       " 79,\n",
       " 80,\n",
       " 81,\n",
       " 82,\n",
       " 83,\n",
       " 84,\n",
       " 85,\n",
       " 86,\n",
       " 87,\n",
       " 88,\n",
       " 89,\n",
       " 90,\n",
       " 91,\n",
       " 92,\n",
       " 93,\n",
       " 94,\n",
       " 95,\n",
       " 96,\n",
       " 97,\n",
       " 98,\n",
       " 99]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "futures = []\n",
    "for i in range(0, 100):\n",
    "    future = client.submit(gen_data, 5, 16, i)\n",
    "    futures.append(future)\n",
    "results = client.gather(futures)\n",
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Once millions of data points are generated, we can combine them and split them into training and validation datasets:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 / 350\n",
      "10 / 350\n",
      "20 / 350\n",
      "30 / 350\n",
      "40 / 350\n",
      "50 / 350\n",
      "60 / 350\n",
      "70 / 350\n",
      "80 / 350\n",
      "90 / 350\n",
      "100 / 350\n",
      "110 / 350\n",
      "120 / 350\n",
      "130 / 350\n",
      "140 / 350\n",
      "150 / 350\n",
      "160 / 350\n",
      "170 / 350\n",
      "180 / 350\n",
      "190 / 350\n",
      "200 / 350\n",
      "210 / 350\n",
      "220 / 350\n",
      "230 / 350\n",
      "240 / 350\n",
      "250 / 350\n",
      "260 / 350\n",
      "270 / 350\n",
      "280 / 350\n",
      "290 / 350\n",
      "300 / 350\n",
      "310 / 350\n",
      "320 / 350\n",
      "330 / 350\n",
      "340 / 350\n",
      "0 / 150\n",
      "10 / 150\n",
      "20 / 150\n",
      "30 / 150\n",
      "40 / 150\n",
      "50 / 150\n",
      "60 / 150\n",
      "70 / 150\n",
      "80 / 150\n",
      "90 / 150\n",
      "100 / 150\n",
      "110 / 150\n",
      "120 / 150\n",
      "130 / 150\n",
      "140 / 150\n"
     ]
    }
   ],
   "source": [
    "import pathlib\n",
    "\n",
    "files = list(pathlib.Path('datafiles/').glob('*.pth'))\n",
    "trn_size = int(len(files)*0.7)\n",
    "trn_files = files[:trn_size]\n",
    "val_files = files[trn_size:]\n",
    "\n",
    "trn_x = []\n",
    "trn_y = []\n",
    "count = 0\n",
    "\n",
    "for i in trn_files:\n",
    "    tensor = torch.load(i)\n",
    "    if count % 10 == 0:\n",
    "        print(count,'/',len(trn_files))\n",
    "    trn_x.append(tensor[0])\n",
    "    trn_y.append(tensor[1])\n",
    "    count += 1\n",
    "\n",
    "X = torch.cat(trn_x)\n",
    "Y = torch.cat(trn_y)\n",
    "torch.save((X,Y), 'trn.pth')\n",
    "\n",
    "val_x = []\n",
    "val_y = []\n",
    "count = 0\n",
    "\n",
    "for i in val_files:\n",
    "    tensor = torch.load(i)\n",
    "    if count % 10 == 0:\n",
    "        print(count,'/',len(val_files))\n",
    "    val_x.append(tensor[0])\n",
    "    val_y.append(tensor[1])\n",
    "    count += 1\n",
    "\n",
    "X = torch.cat(val_x)\n",
    "Y = torch.cat(val_y)\n",
    "torch.save((X,Y), 'val.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "We created two data files, `trn.pth` and `val.pth`, for training and validation. We can define a new PyTorch `Dataset` that loads the data from file, and write it out as a module so the training script can import it. This dataset takes `rank` and `world_size` arguments for distributed training: it loads the whole dataset into GPU memory and samples the data points according to the rank id, so that datasets with different rank ids yield different data points."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Writing filedataset.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile filedataset.py\n",
    "import torch\n",
    "\n",
    "\n",
    "class OptionDataSet(torch.utils.data.Dataset):\n",
    "    def __init__(self, filename, rank=0, world_size=5):\n",
    "        tensor = torch.load(filename)\n",
    "        self.tensor = (tensor[0].cuda(), tensor[1].cuda())\n",
    "        self.length = len(self.tensor[0]) // world_size\n",
    "        self.world_size = world_size\n",
    "        self.rank = rank\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        index = index * self.world_size + self.rank\n",
    "        return self.tensor[0][index], self.tensor[1][index]\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.length"
   ]
  },
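  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the `index * self.world_size + self.rank` scheme above assigns disjoint, interleaved sample indices to each rank. A toy sketch (hypothetical numbers, not part of the training pipeline) with 4 ranks and 12 samples:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration of the rank-based sharding used by OptionDataSet above:\n",
    "# each rank reads every world_size-th sample, starting at its own rank id.\n",
    "world_size = 4\n",
    "num_samples = 12\n",
    "for rank in range(world_size):\n",
    "    indices = [i * world_size + rank for i in range(num_samples // world_size)]\n",
    "    print('rank', rank, '->', indices)"
   ]
  },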
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "When training deep learning models, one effective way to prevent over-fitting is to have a separate validation dataset to monitor the out-of-sample performance. When the performance on the validation dataset declines, over-fitting is happening, so we can stop the training. We put everything together into one script that can train the model efficiently on multiple GPUs:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting distributed_training.py\n"
     ]
    }
   ],
   "source": [
    "%%writefile distributed_training.py\n",
    "import torch\n",
    "from ignite.engine import Engine, Events\n",
    "from torch.nn import MSELoss\n",
    "from ignite.contrib.handlers.param_scheduler import CosineAnnealingScheduler\n",
    "from apex import amp\n",
    "import argparse\n",
    "import os\n",
    "from apex.parallel import DistributedDataParallel\n",
    "import apex\n",
    "from apex.optimizers import FusedLAMB\n",
    "from model import Net\n",
    "from filedataset import OptionDataSet\n",
    "from ignite.metrics import MeanAbsoluteError\n",
    "import ignite\n",
    "import shutil\n",
    "import torch.distributed as dist\n",
    "\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\"--local_rank\", default=0, type=int)\n",
    "parser.add_argument(\"--path\", default=None)\n",
    "parser.add_argument(\"--mae_improv_tol\", default=0.002, type=float)\n",
    "args = parser.parse_args()\n",
    "\n",
    "args.distributed = False\n",
    "if 'WORLD_SIZE' in os.environ:\n",
    "    args.distributed = int(os.environ['WORLD_SIZE']) > 1\n",
    "\n",
    "if args.distributed:\n",
    "    torch.cuda.set_device(args.local_rank)\n",
    "    torch.distributed.init_process_group(backend='nccl',\n",
    "                                         init_method='env://')\n",
    "\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "trn_dataset = OptionDataSet(filename='./trn.pth',\n",
    "                            rank=dist.get_rank(),\n",
    "                            world_size=int(os.environ['WORLD_SIZE']))\n",
    "trn_dataset = torch.utils.data.DataLoader(trn_dataset,\n",
    "                                          batch_size=1024,\n",
    "                                          shuffle=True,\n",
    "                                          num_workers=0)\n",
    "\n",
    "val_dataset = OptionDataSet(filename='./val.pth',\n",
    "                            rank=dist.get_rank(),\n",
    "                            world_size=int(os.environ['WORLD_SIZE']))\n",
    "val_dataset = torch.utils.data.DataLoader(val_dataset,\n",
    "                                          batch_size=1024,\n",
    "                                          shuffle=False,\n",
    "                                          num_workers=0)\n",
    "\n",
    "model = Net().cuda()\n",
    "optimizer = FusedLAMB(model.parameters(), lr=1e-3)\n",
    "loss_fn = MSELoss()\n",
    "\n",
    "\n",
    "model = apex.parallel.convert_syncbn_model(model, channel_last=True)\n",
    "model, optimizer = amp.initialize(model, optimizer, opt_level='O1')\n",
    "\n",
    "\n",
    "best_mae = 100000\n",
    "\n",
    "if args.path is not None:\n",
    "    def resume():\n",
    "        global best_mae\n",
    "        checkpoint = torch.load(args.path)\n",
    "        best_mae = checkpoint['best_mae']\n",
    "        model.load_state_dict(checkpoint['state_dict'])\n",
    "        amp.load_state_dict(checkpoint['amp'])\n",
    "        optimizer.load_state_dict(checkpoint['optimizer'])\n",
    "    resume()\n",
    "\n",
    "\n",
    "if args.distributed:\n",
    "    model = DistributedDataParallel(model)\n",
    "    \n",
    "\n",
    "def train_update(engine, batch):\n",
    "    model.train()\n",
    "    optimizer.zero_grad()\n",
    "    x = batch[0]\n",
    "    y = batch[1]\n",
    "    y_pred = model(x)\n",
    "    loss = loss_fn(y, y_pred[:, 0])\n",
    "    with amp.scale_loss(loss, optimizer) as scaled_loss:\n",
    "        scaled_loss.backward()\n",
    "    optimizer.step()\n",
    "    return loss.item()\n",
    "\n",
    "trainer = Engine(train_update)\n",
    "log_interval = 500\n",
    "\n",
    "scheduler = CosineAnnealingScheduler(optimizer, 'lr', 1e-5, 5e-6,\n",
    "                                     len(trn_dataset),\n",
    "                                     start_value_mult=0.999, end_value_mult=0.999,\n",
    "                                     save_history=False\n",
    "                                     )\n",
    "trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)\n",
    "\n",
    "\n",
    "def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):\n",
    "    torch.save(state, filename)\n",
    "    if is_best:\n",
    "        shutil.copyfile(filename, 'check_points/model_best.pth.tar')\n",
    "\n",
    "\n",
    "@trainer.on(Events.ITERATION_COMPLETED)\n",
    "def log_training_loss(engine):\n",
    "    iter = (engine.state.iteration - 1) % len(trn_dataset) + 1\n",
    "    if iter % log_interval == 0:\n",
    "        print('loss', engine.state.output, 'iter', engine.state.iteration,\n",
    "              'lr', scheduler.get_param())\n",
    "\n",
    "\n",
    "metric = MeanAbsoluteError()\n",
    "loss_m = ignite.metrics.Loss(loss_fn)\n",
    "\n",
    "# run eval at one process only\n",
    "def eval_update(engine, batch):\n",
    "    model.eval()\n",
    "    x = batch[0]\n",
    "    y = batch[1]\n",
    "    y_pred = model(x)\n",
    "    return y, y_pred[:, 0]\n",
    "evaluator = Engine(eval_update)\n",
    "metric.attach(evaluator, \"MAE\")\n",
    "loss_m.attach(evaluator, \"loss\")\n",
    "        \n",
    "@trainer.on(Events.EPOCH_COMPLETED)\n",
    "def log_evalnumber(engine):\n",
    "    global best_mae\n",
    "    mae_improv_tol = args.mae_improv_tol  # default 0.002 or 0.2% improvement\n",
    "    evaluator.run(val_dataset, max_epochs=1)\n",
    "    metrics = evaluator.state.metrics\n",
    "    average_tensor = torch.tensor([metrics['MAE'], metrics['loss']]).cuda()\n",
    "    torch.distributed.reduce(average_tensor, 0, op=torch.distributed.ReduceOp.SUM)\n",
    "    torch.distributed.broadcast(average_tensor, 0)\n",
    "    average_tensor = average_tensor/int(os.environ['WORLD_SIZE'])\n",
    "\n",
    "    mae = average_tensor[0].item()\n",
    "    is_best = False\n",
    "    if (1 - mae / best_mae) >= mae_improv_tol or \\\n",
    "            (engine.state.epoch == engine.state.max_epochs and\n",
    "             mae < best_mae):\n",
    "        best_mae = mae\n",
    "        is_best = True\n",
    "\n",
    "    # print(\"RANK {}   Val Results - Epoch: {}  Avg MAE: {:.5f} loss: {:.5f} BEST MAE: {:.5f}\"\n",
    "    #      .format(dist.get_rank(), trainer.state.epoch, metrics['MAE'], metrics['loss'], best_mae))\n",
    "\n",
    "    if dist.get_rank() == 0:\n",
    "        print('Epoch {}/{}'.format(engine.state.epoch, engine.state.max_epochs))\n",
    "        print('Best MAE Improvement Tolerance for checkpointing: {}%'.format(100 * mae_improv_tol))\n",
    "        print(\"RANK {} AVG {} NGPUs, best-mae: {:.5f} mae: {:.5f} loss: {:.5f}\".format(\n",
    "            dist.get_rank(),\n",
    "            int(os.environ['WORLD_SIZE']),\n",
    "            best_mae,\n",
    "            average_tensor[0].item(),\n",
    "            average_tensor[1].item()))\n",
    "        fname = 'check_points/current_pth.tar'\n",
    "        if is_best:\n",
    "            save_checkpoint({'epoch': trainer.state.epoch,\n",
    "                             'state_dict': model.module.state_dict(),\n",
    "                             'best_mae': best_mae,\n",
    "                             'optimizer': optimizer.state_dict(),\n",
    "                             'amp': amp.state_dict()\n",
    "                             }, is_best,\n",
    "                            filename=fname)\n",
    "        inputs = torch.tensor([[110.0, 100.0, 120.0, 0.35, 0.1, 0.05]]).cuda()\n",
    "        res = model(inputs)\n",
    "        print('test one example:', res.item())\n",
    "\n",
    "trainer.run(trn_dataset, max_epochs=2000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Compared to the last notebook, this script is a little more complicated because\n",
     "* it handles the validation dataset evaluation\n",
     "* it serializes the model into a file and keeps track of the best-performing model based on the mean absolute error (MAE)\n",
     "* it can resume the training from a checkpoint file\n",
     "\n",
     "We can launch the distributed training with the following command:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ngpus=!echo $(nvidia-smi -L | wc -l)\n",
    "!python -m torch.distributed.launch --nproc_per_node={ngpus[0]} distributed_training.py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need some patience to train the pricing model until it converges."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Inference and Greeks\n",
     "Once the training has converged, the best-performing model is saved into the `check_points/` directory. \n",
     "\n",
     "To get a good model, you need millions of data points to train it to convergence. This usually takes 10-20 hours on a single DGX-1 machine with 8 GPUs. We trained the model with 10 million training data points and 5 million validation data points. We did not explore the minimum number of training samples but simply used a large number; you may get away with fewer data points for training. \n",
     "\n",
     "To save time, you can run the following commands to download the trained weights and use them for inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset is already present. No need to re-download it.\n"
     ]
    }
   ],
   "source": [
    "! ((test ! -f './check_points/model_best.pth.tar' ||  test ! -f './check_points/512/model_best.pth.tar') && \\\n",
    "  bash ./download_data.sh) || echo \"Dataset is already present. No need to re-download it.\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "We can load the model parameters and use them for inference:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[18.7140]], device='cuda:0', grad_fn=<AddmmBackward>)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from model import Net\n",
    "import torch\n",
    "checkpoint = torch.load('check_points/model_best.pth.tar')\n",
    "model = Net().cuda()\n",
    "model.load_state_dict(checkpoint['state_dict'])\n",
    "inputs = torch.tensor([[110.0, 100.0, 120.0, 0.35, 0.1, 0.05]]).cuda()\n",
    "model(inputs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "One of the benefits of building a deep learning model is that the [Greeks](<https://en.wikipedia.org/wiki/Greeks_(finance)#First-order_Greeks>) can be computed easily: \n",
     "we just need to take advantage of the auto-grad feature in PyTorch. The following is an example that computes the first-order derivatives of a multivariable polynomial function. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(tensor([24., 36.], grad_fn=<AddBackward0>),)\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import grad\n",
     "'''\n",
     "z = (x*y)^2\n",
     "x = 3, y = 2\n",
     "\n",
     "first-order derivatives: [24, 36]\n",
     "'''\n",
     "inputs = torch.tensor([3.0, 2.0], requires_grad=True)\n",
     "z = (inputs[0] * inputs[1]) ** 2\n",
    "first_order_grad = grad(z, inputs, create_graph=True)\n",
    "print(first_order_grad)"
   ]
  },
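  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can cross-check the autograd result against the analytic derivatives: for `z = (x*y)**2`, `dz/dx = 2*x*y**2` and `dz/dy = 2*y*x**2`, which at `x = 3, y = 2` give 24 and 36, matching the gradient above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analytic cross-check of the autograd result: z = (x*y)**2,\n",
    "# so dz/dx = 2*x*y**2 and dz/dy = 2*y*x**2.\n",
    "x, y = 3.0, 2.0\n",
    "print(2 * x * y**2, 2 * y * x**2)"
   ]
  },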
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "We can use the `grad` function to compute the first-order derivatives with respect to the parameters `K, B, S0, sigma, mu, r`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[-6.7092e-01, -2.1257e-02,  7.8896e-01,  1.9219e+01,  4.8331e+01,\n",
       "         -1.8419e+01]], device='cuda:0')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = torch.tensor([[110.0, 100.0, 120.0, 0.35, 0.1, 0.05]]).cuda()\n",
    "inputs.requires_grad = True\n",
    "x = model(inputs)\n",
    "x.backward()\n",
    "first_order_gradient = inputs.grad\n",
    "first_order_gradient"
   ]
  },
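  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the input ordering is `K, B, S0, sigma, mu, r`, Delta (the sensitivity of the price to the spot price `S0`) is the third component of the gradient computed above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delta is the derivative of the price with respect to the spot price S0,\n",
    "# the third input column in the (K, B, S0, sigma, mu, r) ordering.\n",
    "delta = first_order_gradient[0, 2]\n",
    "print('Delta:', delta.item())"
   ]
  },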
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "Here we are going to plot the Delta graph:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fcd9e1ac7b8>]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de5xcZZ3n8c+vqm/pdJJO0o3EXEgCAeUONiDiBRYcAyo4OCgZbzhodARvuM7i6AKLs+uos7rDLoJREUHloiJkxjjIuijKzQRCIBcCSbikQyCdC0l3uruqq+q3f9SppNJUdVW6+3RV1/m+X69+peo5p6p/ffqkvv08z7mYuyMiItEVq3QBIiJSWQoCEZGIUxCIiEScgkBEJOIUBCIiEVdX6QIOVltbm8+dO7fSZYiIjCuPPfbYdndvL7Rs3AXB3LlzWbFiRaXLEBEZV8zshWLLNDQkIhJxCgIRkYhTEIiIRJyCQEQk4hQEIiIRpyAQEYk4BYGISMQpCEQOwt5Eilsefp6BdKbSpYiMGgWByEG4+aHnueqeNfzqsc5KlyIyahQEIgdh6+4+AF7c2VvhSkRGj4JA5CB0dScAeG773gpXIjJ6FAQiB+HlPQoCqT0KApGD8HIwNPTc9r1kMrrft9QGBYFImVLpDF3dCdpaGkikMmzd01/pkkRGhYJApEw79ibJOJw2fzoAz3VpeEhqg4JApEyvBD2AE2ZNAaCrRz0CqQ2hBYGZ3WRm28xsdZHlHzKzJ83sKTN7yMxOCKsWkdHw8u7sB/8bDp0MwM69A5UsR2TUhNkjuBlYOMTy54B3uPtxwNeBJSHWIjJirwSHjh75uknEY8auvckKVyQyOkK7VaW7P2Bmc4dY/lDe00eAWWHVIjIadvdmP/inTqxnanM9OxQEUiOqZY7gUuC3xRaa2WIzW2FmK7q6usawLJH9uhMpGupiNNbFaW1u4NVeBYHUhooHgZmdRTYI/kuxddx9ibt3uHtHe3v72BUnkqenP8WkxmwnelpzA7sUBFIjKhoEZnY88EPgAnffUclaRErZm0jR0pQNgtbmenZpslhqRMWCwMzmAHcBH3H3ZypVh0i5ehIpWnI9gonqEUjtCG2y2MxuA84E2sysE7gaqAdw9xuBq4DpwPfMDCDl7h1h1SMyUt39+4OgNRgacneC/Vdk3ArzqKFFJZZ/AvhEWN9fZLT1JFLMmNIEwLSJ9Qyknb3J9L5wEBmvKj5ZLDJe5A8NtTY3AOhcAqkJCgKRMvX0758snpoLAs0TSA1QEIiUqTuRoqWxHsgODQHsVI9AaoCCQKQMyVSGZCpDS2Mc2D809GqvDiGV8U9BIFKGvYkUABPzTigD9QikNigIRMrQn0oDMKE+2yOYPKEeM3SZCakJCgKRMiQGMgA01mf/y8RjRuuEenYqCKQGKAhEypBMZ4OgIR7f1za1uYFdmiOQGqAgEClDMhUEQd3+/zJTJzboPAKpCQoCkTIkgjmCxvwgaK5Xj0BqgoJApAyJQj2CZvUIpDYoCETKkAuCxsFDQ5oslhqgIBApQ8E5guYGEqkMfcl0pcoSGRUKApEy7O8R5B81FFxmQr0CGecUBCJlSBYZGgJdgVTGPwWBSBkKHzWkK5BKbVAQiJSh8BxBdmhIh5DKeKcgEClDwTkCDQ1JjVAQiJShUI+gdUKuR6AgkPFNQSBShkQqTV3MiMf236i+Lh5jclOdegQy7ikIRMqQTGUO6A3kZE8q0xyBjG8KApEyJFKZA44YyslegVQ9AhnfQgsCM7vJzLaZ2eoiy83MrjOzDWb2pJmdHFYtIiNVtEfQXK8gkHEvzB7BzcDCIZafCywIvhYDN4RYi8iIZHsE8de0Zy9FraEhGd/qwnpjd3/AzOYOscoFwC3u7sAjZtZqZjPcfWtYNYkMV/EegYaGalUm4yTTGRKpDIlUmmQq+zj/32QqQzKdJpV2Mu6k
M5B2x91JZ7Jf+e2ZA9qcjLPvcTqTfZ2Tbcstw/fX9Ob50znrDYeM+s8aWhCUYSawOe95Z9D2miAws8Vkew3MmTNnTIoTyZdIpWmIvzYIpk1soDeZpn8gTVP9a3sMEq5MxtndN8Ce/gH29KWCf1/7vCeRJpFK532Qp4MP8QyJgcH/pkmmMwykvXQBIYgZxMyImWEGtv9ANRrqYjUXBGVz9yXAEoCOjo7K/HYk0hKpzL77FedrDc4ufrV3gEOnKAhG0+7eATbv6qVzVy+du/ro6kmwvTvJ9p4EXd0Jtvck2LE3STpT/CMhZjB5Qj0TG+porI/RWBenoS5GYzzGxMY6psZjNNbHaIjHsu255XXZ54PbGuuy62ZfE6exPkZ9PEZdLPvBHY8Z8dj+D/J4zIjFjLgZsRjEB7cNarf8T/0xVMkg2ALMzns+K2gTqTqJVKZwjyC43tCOvQkOndI01mWNe33JNBu7erJf23rY0NXD89t72byrl+7+1AHr1seN9pZG2iY1cuiUJo6bOYW2SQ1Mn9hIa3M9k5vqmTyhnskT6vY9ntgQr9iH63hSySBYClxuZrcDpwG7NT8g1SqZyjA5OJM43/SWRgB29GieoJT+gTRrXtrNyhdfZeXmV3my81U6d/XhwR/0MYM505qZ1zaRjrlTmT21mdnTJjBrajMzWyfQ2lyvD/WQhBYEZnYbcCbQZmadwNVAPYC73wgsA84DNgC9wMfDqkVkpIr1CNpasj2C7T2JsS6p6u1NpHho4w4e3LCdlS/uYu3WPfvG3We2TuDE2a1c9KbZHN7ewhGHtDC3rbngkVkSvjCPGlpUYrkDl4X1/UVGUzKVLjhH0DYp2yNQEIC7s7FrL39Yv40/rO/iL8/tJJnO0NwQ54RZrXzibfM5aXYrJ85p5ZBJGkarJuNislik0hKpDI0FegSTGutoqIuxPaJDQ33JNA9v2s79T3dx//ptdO7qA2DBIS1ccsZczjyqnY7DphU89Faqh4JApAzJIkcNmWUnMLd3R6dH4O48+txO7ly+mWWrt9I/kGFCfZwzjmjj0+84nDOPamfW1OZKlykHQUEgUoZicwSQnSfoisDQUP9Aml893smP/vwcm7r2MqmxjgtPnsW5xx7KqfOmaXx/HFMQiJQh2yMo/EHX1tLIS7v7x7iisbOjJ8EtD7/ArY+8wM69SY6bOYV/uegE3n3cDCY06MO/FigIRMqQTBfvEbRPauTJLbvHuKLwberq4Qd/eo67Hu8kkcpwzhsP4RNvm89p86bpMM4aoyAQKSGVzpDOeMHLUEO2R7Bzb5JMxonFxv8H5JqXdvO9P2xk2VNbqY/HeP/Js7j0rfM44pCWSpcmIVEQiJSQTL/2NpX52loaSGecXb3JfSeYjUfb9vTzP5at4+4nXqKlsY5Pv+Nw/u6MebRPGr8/k5RHQSBSQmIgd+P6IkGw71yC8RkEqXSGWx5+ge/e9wyJVIbLzjqcxW8/nCkFzqSW2qQgEClhf4+g+GQxZE8qO4pJY1bXaFjz0m6+dOcqnn65m7cf2c5/O/8Y5rVNrHRZMsYUBCIllOwRBEHQNY7OJchknB/+eRPfvnc9U5sbuOFDJ7Pw2EM1CRxRCgKREpLpNFB8jqC9ZXxdZmLr7j6+dOcqHtq4g3cd8zq+ceHxTJvYUOmypIIUBCIl9A8MPVk8eUIdDfHYuDip7N+ffImv/no1A+kM/3zhcXzwlNnqBYiCQKSU3BxBsaEhM2N6SwPbu6v3ekPd/QNcfc8a7lq5hRNmt/K/Pnii5gJkHwWBSAmJEj0CyM4TVOvQ0NqX9vD3P3uMzTt7+fzZC7j8Px1BfZGT4ySaFAQiJezvERS/nEJbSwPbqnCy+M4Vm/mvd6+mtbmeOz51OqfMnVbpkqQKKQhESkgMZCeLiw0NQbZHsHbrnrEqqaT+gTRX37OGO1Zs5i2HT+e6RSftO7pJZDAFgUgJpeYIAF43uYntPUlS6Qx1FR522dGTYPGtj/HY
C7u47KzDueKdRxGvgUtfSHgUBCIllDNHMGvqBNIZZ+vufmZPq9y1+Dds6+bjNy9n254E3/vQyZx33IyK1SLjh2aMREooZ44g9+Gfu0NXJTy4YTt//b2H6EtmuONTpysEpGzqEYiUkJsjKNUjAOjc1QtMH4uyDvAfq7fy2dtWMr+thR9d0qE7hMlBURCIlFDOHMGMKRMwq0yP4K7HO/nyL5/khFlT+PHHT9XF4uSgaWhIpIRy5gga6mLMmNzE5l29Y1UWAD995AWuuHMVp82bxq2XnqYQkGEJNQjMbKGZrTezDWZ2ZYHlc8zsfjNbaWZPmtl5YdYjMhzJdIaYQV2JI29mTWvmhR1jFwTf/+NGvnb3as5+wyHcdMkpTGxUB1+GJ7QgMLM4cD1wLnA0sMjMjh602teAO939JOBi4Hth1SMyXIlUhoa6WMlr8hw9YzLrtu4hk/FQ63F3vnPfM3zjt0/znuNncONH3kRTkfspi5QjzB7BqcAGd9/k7kngduCCQes4MDl4PAV4KcR6RIYlmcoMecRQzjGvn0xvMs2m7XtDq8Xd+affrOO63z/LBzpm8a8Xn6TLRciIhbkHzQQ25z3vDNryXQN82Mw6gWXAZwu9kZktNrMVZraiq6srjFpFikqk0kPOD+S86bCpADy8aUcodaQzzj/++il+9OfnuOQtc/nnC4/XiWIyKir9p8Qi4GZ3nwWcB9xqZq+pyd2XuHuHu3e0t7ePeZESbYlUZsgjhnLmtU1k9rQJ/OmZ0f9jZSCd4Yo7n+C2v2zmsrMO5+r3Hk1MISCjJMzZpS3A7Lzns4K2fJcCCwHc/WEzawLagG0h1iVyUHJzBKWYGaccNo0/bdiOu4/adf6TqQyX/fxx7lv7Cv+w8Cg+c+YRo/K+Ijlh9giWAwvMbJ6ZNZCdDF46aJ0XgbMBzOyNQBOgsR+pKslUhoYyx+GPnzWFru4EL+/pH5XvnUpn+MIdK7lv7Stc896jFQISitCCwN1TwOXAvcA6skcHrTGza83s/GC1LwGfNLNVwG3AJe4e7iEXIgcpkcrQWOZROSfMbgVg1ebdI/6+mYzzlbueYtlTL/O1d7+RS86YN+L3FCkk1AOP3X0Z2Ung/Lar8h6vBc4IswaRkUqm0jSW2SN444zJNMRjrHh+JwuPPXTY3zOTcb52z2p+8Vgnnzt7AZ942/xhv5dIKZWeLBapeslUhsb68v6rNNXHecsR07l37cvDPp9gT/8Al/5kOT9/9EU+c+bhfPGcBcN6H5FyKQhESkgcxBwBwPtOnMnmnX3ct+6Vg/5e23sSLFryCH96djvXXnAMX37XUbq5vIROQSBSwsH0CADec/wM5rVN5MpfPcnKF3eV/botr/bxgRsfZmNXDz/8WAcfPX2uQkDGhIJApISD7RHUxWP8+JJTqI/H+KffrCvrNRu7erjohofo6knw00tP48yjDhluuSIHTUEgUkK5l5jIN7dtIovfPp/HXtjF+pe7i67n7vxh/TYuuvFhkukMdyw+nQ7dYF7GmC5XKFJCuZeYGOzCk2fxrf9Yz21/eZFrzj9mX/u27n6uvmcNG7t66OpOsKt3gMOmN3Pzx09lXtvE0SxdpCwKApESkmVeYmKwaRMbeNexh/LrlVu48tw30FQfp6s7wcXff4Stu/t5+5FtvOmwaXQcNpX3nvD6YYWNyGhQEIiUUO4lJgpZdOps/m3VSyx5YBOHTW/mut8/y8t7+rn10lM1BCRVQ0EgMoR0xkll/KDnCHJOnz+d0+dP5zv3PQNkb3f540tOUQhIVVEQiAwhmSp9m8qhmBn/8oETuOWh53nz/OkccUgLs6fpxvJSXcoKAjNbAHyD7J3GmnLt7q7z3qWm5YJgOHMEOTNbJ/CV8944WiWJjLpy9+4fAzcAKeAs4Bbgp2EVJVItEqk0MPwegch4UO7ePcHdfw+Yu7/g7tcA7w6vLJHqkBiFHoFItSt3jiAR3DnsWTO7nOwNZlrCK0ukOiRGOEcg
Mh6Uu3d/HmgGPge8Cfgw8NGwihKpFqMxRyBS7crdu+e6e4+7d7r7x939/cCcMAsTqQa5OYLhHj4qMh6UGwRfKbNNpKaM9PBRkfFgyDkCMzsXOA+YaWbX5S2aTPYIIpGapsliiYJSk8UvAY8B5wf/5nQDXwyrKJFqoR6BRMGQQeDuq4BVZvbT4Gb0IpGyv0egOQKpXaWGhp4CPHj8muXufnw4ZYlUh2RaJ5RJ7Ss1NPSeMalCpEolBjRHILVvyL07OIv4BXd/IWhaEDzeBuws9eZmttDM1pvZBjO7ssg6HzCztWa2xsx+ftA/gUiIkmnNEUjtK/eic58EFgPTgMOBWcCNwNlDvCYOXA+8E+gElpvZUndfm7fOArKHoZ7h7rvMTDdqlaqiHoFEQbl792XAGcAeAHd/Fij1oX0qsMHdN7l7ErgduGDQOp8Ernf3XcH7biu3cJGxoB6BREG5e3ci+DAHwMzqCCaRhzAT2Jz3vDNoy3ckcKSZPWhmj5jZwjLrERkT+641FFcQSO0q96JzfzSzfwQmmNk7gc8A/zZK338BcCbZ4aYHzOw4d381fyUzW0x2aIo5c3RlCxk7uRvXFzpqTqRWlPtnzpVAF/AU8ClgGfC1Eq/ZAszOez4raMvXCSx19wF3fw54hmwwHMDdl7h7h7t3tLe3l1myyMgN98b1IuNJWT0Cd8+Y2d3A3e7eVeZ7LwcWmNk8sgFwMfC3g9a5G1gE/NjM2sgOFW0q8/1FQpdQEEgEDLmHW9Y1ZrYdWA+sN7MuM7uq1BsHZyJfDtwLrAPudPc1ZnatmZ0frHYvsMPM1gL3A1929x0j+YFERlO2R6CziqW2leoRfJHs0UKnBEM3mNl84AYz+6K7f3eoF7v7MrLDSPltV+U9duCK4Euk6iRSGR0xJDWv1B7+EWBRLgQA3H0TujGNRERiIK2hIal5pfbwenffPrgxmCeoD6ckkeqRSGVorNfQkNS2UkGQHOYykZrQrx6BRECpOYITzGxPgXYDmkKoR6SqJFIZJjWVe7qNyPhU6n4E6hNLpCVSGdo1NCQ1Tn1ekSFosliiQHu4yBASOo9AIkBBIDKERCpNU73+m0ht0x4uMoT+AfUIpPYpCESGkEilaVSPQGqc9nCRItIZZyDtNKlHIDVOQSBSRCKVBlCPQGqe9nCRInS/YokK7eEiReRuU9mkE8qkxikIRIroHwiGhtQjkBqnPVykiFyPQIePSq1TEIgUkZss1gllUuu0h4sU0T+gHoFEg4JApAgdPipRoT1cpIjc4aM6oUxqnYJApIh+9QgkIrSHixShE8okKrSHixShE8okKkINAjNbaGbrzWyDmV05xHrvNzM3s44w6xE5GDqhTKIitD3czOLA9cC5wNHAIjM7usB6k4DPA4+GVYvIcOiEMomKMP/UORXY4O6b3D0J3A5cUGC9rwPfBPpDrEXkoO07fFQ9AqlxYe7hM4HNec87g7Z9zOxkYLa7/2aoNzKzxWa2wsxWdHV1jX6lIgX0D2RoiMeIxazSpYiEqmJ/6phZDPgO8KVS67r7EnfvcPeO9vb28IsTIbg7mXoDEgFh7uVbgNl5z2cFbTmTgGOBP5jZ88CbgaWaMJZqkUhlaNQRQxIBYQbBcmCBmc0zswbgYmBpbqG773b3Nnef6+5zgUeA8919RYg1iZStP6kegURDaHu5u6eAy4F7gXXAne6+xsyuNbPzw/q+IqOlbyBNc4N6BFL76sJ8c3dfBiwb1HZVkXXPDLMWkYPVm1QQSDSo3ytSRF8yrbOKJRIUBCJFaGhIokJBIFJEbzLFBAWBRICCQKSI/oEME+pDnUYTqQoKApEiepMpDQ1JJCgIRIroG0hraEgiQUEgUkAm48HQkIJAap+CQKSA3G0q1SOQKFAQiBTQm8wGgeYIJAoUBCIF9AVBoBPKJAoUBCIF9A2oRyDRoSAQKSA3NKTJYokCBYFIAbmhIU0WSxQoCEQK6BtIAeoRSDQoCEQK6EtmAGhu
0CUmpPYpCEQK6E2qRyDRoSAQKaB/QHMEEh0KApECejVZLBGiIBApIHcegYaGJAoUBCIF9CXTNNbFiMes0qWIhE5BIFJAb1KXoJboUBCIFNCTSDGpSYeOSjSEGgRmttDM1pvZBjO7ssDyK8xsrZk9aWa/N7PDwqxHpFzd/QO0NNZXugyRMRFaEJhZHLgeOBc4GlhkZkcPWm0l0OHuxwO/BL4VVj0iB6O7Xz0CiY4wewSnAhvcfZO7J4HbgQvyV3D3+929N3j6CDArxHpEytbdn2JSo4JAoiHMIJgJbM573hm0FXMp8NtCC8xssZmtMLMVXV1do1iiSGGaI5AoqYrJYjP7MNABfLvQcndf4u4d7t7R3t4+tsVJJHX3D9CiIJCICHNP3wLMzns+K2g7gJmdA3wVeIe7J0KsR6Qs7h70CDRZLNEQZo9gObDAzOaZWQNwMbA0fwUzOwn4PnC+u28LsRaRsiVSGQbSTovmCCQiQgsCd08BlwP3AuuAO919jZlda2bnB6t9G2gBfmFmT5jZ0iJvJzJmuvuzVx6drKEhiYhQ93R3XwYsG9R2Vd7jc8L8/iLD0ZPIBoHmCCQqqmKyWKSadPcPADBJJ5RJRCgIRAbp6VePQKJFQSAySHduaEiTxRIRCgKRQfb0ZYeGJuvwUYkIBYHIIDv3JgGY1tJQ4UpExoaCQGSQHXuTNNbFmKj7EUhEKAhEBtnek6CtpREz3Z1MokFBIDLIjp4k0zUsJBGiIBAZZMfeBNMnKggkOhQEIoO8sic7NCQSFQoCkTyJVJqu7gQzp06odCkiY0ZBIJJn66v9AMxsVRBIdCgIRPK89GofoCCQaFEQiOTpzAWBhoYkQhQEInk6d/VhBodOaap0KSJjRkEgkmfjth7mTGumsU5nFUt0KAhE8mzY1sOCQ1oqXYbImFIQiARS6QybtvdwuIJAIkZBIBJ4fsdeBtLOgkMmVboUkTGlIBAJrHzxVQBOnD2lwpWIjC0FgUjgic2vMqmpjvltGhqSaFEQiADuzsObdnDi7FZiMV1+WqIl1CAws4Vmtt7MNpjZlQWWN5rZHcHyR81sbpj1iBSz5qU9bOray7uOObTSpYiMudDuzm1mceB64J1AJ7DczJa6+9q81S4Fdrn7EWZ2MfBN4INh1SQC2b/+B9JOXzJNTzLF+pf38N37nmViQ5z3Hv/6SpcnMuZCCwLgVGCDu28CMLPbgQuA/CC4ALgmePxL4P+Ymbm7j3Yx96/fxjVL1xzQVui7OK9tHLxeudUV+jEKvbScOgqvU957FVrzNT9TwfcaQf3lvLbsn2l4dRT7XQ6kM2QGLWptrufr7zuWKc26Yb1ET5hBMBPYnPe8Ezit2DrunjKz3cB0YHv+Sma2GFgMMGfOnGEV0zqhnpNmt76mvdDtCAuOENvgpwVeV+CFhd6r8Hql36/wnRNHr46R/UzljauP7vcsXUeh19XHYzTVx2iqj9PSWMec6c2cPGcqTfU6m1iiKcwgGDXuvgRYAtDR0TGs3sJJc6Zy0pypo1qXiEgtCHOyeAswO+/5rKCt4DpmVgdMAXaEWJOIiAwSZhAsBxaY2TwzawAuBpYOWmcp8LHg8d8A/y+M+QERESkutKGhYMz/cuBeIA7c5O5rzOxaYIW7LwV+BNxqZhuAnWTDQkRExlCocwTuvgxYNqjtqrzH/cBFYdYgIiJD05nFIiIRpyAQEYk4BYGISMQpCEREIs7G29GaZtYFvFDpOopoY9BZ0VWo2mtUfSOj+kamlus7zN3bCy0Yd0FQzcxshbt3VLqOoVR7japvZFTfyES1Pg0NiYhEnIJARCTiFASja0mlCyhDtdeo+kZG9Y1MJOvTHIGISMSpRyAiEnEKAhGRiFMQDJOZzTaz+81srZmtMbPPB+3XmNkWM3si+DqvgjU+b2ZPBXWsCNqmmdl9ZvZs8G9F7tZjZkflbaMnzGyPmX2hktvPzG4ys21mtjqvreD2sqzrzGyDmT1pZidXqL5v
m9nTQQ2/NrPWoH2umfXlbccbK1Rf0d+nmX0l2H7rzexdFarvjrzanjezJ4L2Smy/Yp8p4e+D7q6vYXwBM4CTg8eTgGeAo8neg/k/V7q+oK7ngbZBbd8CrgweXwl8swrqjAMvA4dVcvsBbwdOBlaX2l7AecBvyd4N883AoxWq76+AuuDxN/Pqm5u/XgW3X8HfZ/B/ZRXQCMwDNgLxsa5v0PL/CVxVwe1X7DMl9H1QPYJhcvet7v548LgbWEf2HszV7gLgJ8HjnwDvq2AtOWcDG929omeMu/sDZO+Lka/Y9roAuMWzHgFazWzGWNfn7r9z91Tw9BGydwKsiCLbr5gLgNvdPeHuzwEbgFNDK46h67Psza4/ANwWZg1DGeIzJfR9UEEwCsxsLnAS8GjQdHnQVbupUkMvAQd+Z2aPmdnioO117r41ePwy8LrKlHaAiznwP2C1bD8ovr1mApvz1uuk8n8I/B3ZvxBz5pnZSjP7o5m9rVJFUfj3WW3b723AK+7+bF5bxbbfoM+U0PdBBcEImVkL8CvgC+6+B7gBOBw4EdhKtrtZKW9195OBc4HLzOzt+Qs927+s6PHDlr2N6fnAL4Kmatp+B6iG7VWMmX0VSAE/C5q2AnPc/STgCuDnZja5AqVV7e9zkEUc+MdIxbZfgc+UfcLaBxUEI2Bm9WR/YT9z97sA3P0Vd0+7ewb4ASF3d4fi7luCf7cBvw5qeSXXfQz+3Vap+gLnAo+7+ytQXdsvUGx7bQFm5603K2gbc2Z2CfAe4EPBBwXBkMuO4PFjZMfgjxzr2ob4fVbT9qsDLgTuyLVVavsV+kxhDPZBBcEwBWOKPwLWuft38trzx+j+Glg9+LVjwcwmmtmk3GOyk4qrgaXAx4LVPgbcU4n68hzwl1i1bL88xbbXUuCjwZEbbwZ253Xfx4yZLQT+ATjf3Xvz2tvNLB48ng8sADZVoL5iv8+lwMVm1mhm84L6/jLW9QXOAZ52985cQyW2X7HPFMZiHxzLWfFa+gLeSraL9iTwRPB1HnAr8FTQvhSYUaH65pM9KmMVsAb4atA+Hfg98Czwf4FpFdyGE4EdwJS8toptP7KBtBUYIDveemmx7UX2SI3ryf6l+BTQUaH6NpAdJ87tgzcG6+XildwAAAIPSURBVL4/+L0/ATwOvLdC9RX9fQJfDbbfeuDcStQXtN8MfHrQupXYfsU+U0LfB3WJCRGRiNPQkIhIxCkIREQiTkEgIhJxCgIRkYhTEIiIRJyCQGSYzOxaMzun0nWIjJQOHxUZBjOLu3u60nWIjAb1CEQGCa5F/7SZ/czM1pnZL82sObhe/TfN7HHgIjO72cz+JnjNKWb2kJmtMrO/mNkkM4tb9n4By4OLrn0qWHeGmT0QXOd+dYUvCCdCXaULEKlSR5E98/RBM7sJ+EzQvsOzF/LLXd4hd+G8O4APuvvy4OJkfWTPrN3t7qeYWSPwoJn9jux1be519/8eXMageWx/NJEDKQhECtvs7g8Gj38KfC54fEeBdY8Ctrr7cgAPrhhpZn8FHJ/rNQBTyF6zZjlwU3CBsbvd/YmQfgaRsigIRAobPHmWe773IN7DgM+6+72vWZC9JPi7gZvN7DvufsvwyhQZOc0RiBQ2x8xODx7/LfDnIdZdD8wws1MAgvmBOuBe4O+Dv/wxsyODq8IeRvYmKD8Afkj29okiFaMgEClsPdmb+awDppK9wUpB7p4EPgj8bzNbBdwHNJH9kF8LPG7ZG6Z/n2wv/ExglZmtDF73ryH+HCIl6fBRkUGC2wT+u7sfW+FSRMaEegQiIhGnHoGISMSpRyAiEnEKAhGRiFMQiIhEnIJARCTiFAQiIhH3/wE0MzGjPo+4egAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import pylab\n",
    "import numpy as np\n",
    "def compute_delta(S):\n",
    "    # option parameters: [K, B, S0, sigma, mu, r]\n",
    "    inputs = torch.tensor([[110.0, 100.0, S, 0.35, 0.1, 0.05]]).cuda()\n",
    "    inputs.requires_grad = True\n",
    "    x = model(inputs)\n",
    "    x.backward()\n",
    "    first_order_gradient = inputs.grad\n",
    "    return first_order_gradient[0][2]\n",
    "prices = np.arange(10, 200, 0.1)\n",
    "deltas = []\n",
    "for p in prices:\n",
    "    deltas.append(compute_delta(p).item())\n",
    "fig = pylab.plot(prices, deltas)\n",
    "pylab.xlabel('prices')\n",
    "pylab.ylabel('Delta')\n",
    "fig"
   ]
  },
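  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check, the autograd Delta can be compared against a central finite difference. The sketch below uses a stand-in differentiable price function `f`; the trained `model` from the cells above would take its place:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def f(s):\n",
    "    # stand-in smooth 'price' as a function of the spot s\n",
    "    return (s ** 2 + 3.0 * s).sum()\n",
    "\n",
    "s = torch.tensor([100.0], requires_grad=True)\n",
    "f(s).backward()\n",
    "autograd_delta = s.grad.item()\n",
    "\n",
    "# central difference is exact for a quadratic, up to float rounding\n",
    "h = 0.1\n",
    "fd_delta = (f(torch.tensor([100.0 + h])) - f(torch.tensor([100.0 - h]))).item() / (2 * h)\n",
    "print(autograd_delta, fd_delta)  # both close to 2 * 100 + 3 = 203"
   ]
  },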
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculating second order derivatives is easy in PyTorch too; we just need to apply the `grad` function twice. The following example calculates the second order derivatives of a simple polynomial function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([ 8., 24.])\n",
      "tensor([24., 18.])\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from torch.autograd import grad\n",
    "'''\n",
    "z = (xy)^2\n",
    "x = 3, y = 2\n",
    "\n",
    "first order derivatives: [24, 36]\n",
    "d2z/dx2 = 8\n",
    "d2z/dxdy = 24\n",
    "d2z/dy2 = 18\n",
    "'''\n",
    "\n",
    "inputs = torch.tensor([3.0,2.0], requires_grad=True)\n",
    "z = (inputs[0]*inputs[1])**2\n",
    "first_order_grad = grad(z, inputs, create_graph=True)\n",
    "second_order_grad_x, = grad(first_order_grad[0][0], inputs, retain_graph=True)  # [d2z/dx2, d2z/dxdy]\n",
    "second_order_grad_y, = grad(first_order_grad[0][1], inputs)  # [d2z/dydx, d2z/dy2]\n",
    "print(second_order_grad_x)\n",
    "print(second_order_grad_y)"
   ]
  },
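  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For convenience, PyTorch also provides `torch.autograd.functional.hessian`, which wraps the double `grad` call above. Applied to the same polynomial, it returns the full matrix of second order derivatives in one shot:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd.functional import hessian\n",
    "\n",
    "def z(v):\n",
    "    return (v[0] * v[1]) ** 2\n",
    "\n",
    "H = hessian(z, torch.tensor([3.0, 2.0]))\n",
    "print(H)  # [[8, 24], [24, 18]]"
   ]
  },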
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Using this mechanism, we can calculate the second order derivatives $\\frac{\\partial^2 P}{\\partial K \\partial S_0}$, $\\frac{\\partial^2 P}{\\partial B \\partial S_0}$, $\\frac{\\partial^2 P}{\\partial S_0^2}$, $\\frac{\\partial^2 P}{\\partial \\sigma \\partial S_0}$, $\\frac{\\partial^2 P}{\\partial \\mu \\partial S_0}$, $\\frac{\\partial^2 P}{\\partial r \\partial S_0}$ in the following example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.0143,  0.0039,  0.0098, -0.3183,  1.1455, -0.7876]],\n",
       "        device='cuda:0'),)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from torch import Tensor\n",
    "from torch.autograd import Variable\n",
    "from torch.autograd import grad\n",
    "from torch import nn\n",
    "\n",
    "inputs = torch.tensor([[110.0, 100.0, 120.0, 0.35, 0.1, 0.05]]).cuda()\n",
    "inputs.requires_grad = True\n",
    "x = model(inputs)\n",
    "\n",
    "# instead of using loss.backward(), use torch.autograd.grad() to compute gradients\n",
    "# https://pytorch.org/docs/stable/autograd.html#torch.autograd.grad\n",
    "loss_grads = grad(x, inputs, create_graph=True)\n",
    "drv = grad(loss_grads[0][0][2], inputs)\n",
    "drv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Gamma is the second order derivative of the option price with respect to `S`. We can plot the Gamma curve as a function of the stock price:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7f996066dba8>]"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEGCAYAAAB7DNKzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3de5RcZZnv8e9Tu6pv6dzIBUJCCGBEoyBig+iMiogO6BEc0SPM0UGPDt6YcYaZ48HFLMZh1vF4WeNRZzhqREYZZUA4o0aNouL9hgkQJNwkcpHEhNxI0p2+VFfVc/7Yu7qru6uqi6retbt3fp+1srpq1+6qJ7ur+1fv++733ebuiIiI1JJJugAREZndFBQiIlKXgkJEROpSUIiISF0KChERqSubdAEzbenSpb5mzZqkyxARmVPuvPPOve6+rNpjqQuKNWvWsHnz5qTLEBGZU8zs8VqPqetJRETqUlCIiEhdCgoREalLQSEiInUpKEREpC4FhYiI1KWgEBGRuhQUIjPgq3dvZ2CkkHQZIrFQUIi0aNvufv7m5nu44uYtSZciEgsFhUiLgkz4a7R1x8GEKxGJh4JCpEXlq0Tmi6WEKxGJh4JCpEWlKCh0VWFJq0SDwszOM7OHzGybmV1ZZ7+LzMzNrK+d9Yk0olBSQki6JRYUZhYA1wLnA+uAS8xsXZX95gPvA+5ob4UijSlGQaG4kLRKskVxJrDN3R9x9zxwE3Bhlf3+CfgIMNzO4kQaVYqGJlx9T5JSSQbFSuCJivvbo21jzOx04Dh3/1a9JzKzy8xss5lt3rNnz8xXKlJH0dWikHSbtYPZZpYBPg787XT7uvt6d+9z975ly6peoEkkNj/fthfQYLakV5JBsQM4ruL+qmhb2XzgucCPzOwx4Cxggwa0Zbb52G0PAep6kvRKMig2AWvN7AQz6wAuBjaUH3T3g+6+1N3XuPsa4FfABe6u65zKrKSYkLRKLCjcvQBcDtwGPAB8xd3vM7NrzOyCpOoSaVb/sNZ6knTKJvni7r4R2Dhp29U19j27HTWJiMhEs3YwW2SuOWXlwqRLEImFgkKkRb2dYcN8XmeQcCUi8VBQiLSgWPKx61AUtZSHpJSCQqQFlRcr0ppPklYKCpEWPLF/cOy2WhSSVgoKkRYcHBoFYGF3jtGigkLSSUEh0oJ8IVwRsKcjoFjShYsknRQUIi0YiYKiuyPQGIWkloJCpAUjhSJQblEoKCSdFBQiLRjresplKWiMQlJKQSHSgnwxCopOtSgkvRQUIi2oHMzWGIWklYJCpAXloOjOZXXWk6SWgkKkBeWznuZ1BhqjkNRSUIi0IF8oYQad2Yy6niS1FBQiLcgXS3RmM2SDjAazJbUUFCItyBdKdAQZshmjoDEKSSkFhUgLRgolOrIBQcYoOZTUqpAUUlCItGCkUAy7njIGaKlxSScFhUgL8oVwjCLIhL9KJVdQSPooKERakC+U6FCLQlJOQSHSguEoKIIoKIqaSyEppKAQacHvdg9w/JJ5ZINyi0JnPkn6KChEWnBwaJRlvZ1kLGpRaIxCUijRoDCz88zsITPbZmZXVnn8XWZ2r5ltMbOfmdm6JOoUqaZYcgZGCvR2ZcfGKDTpTtIosaAwswC4FjgfWAdcUiUIbnT3U9z9NOCjwMfbXKZITUOj4UWLejuDsTEKrfckaZRki+JMYJu7P+LueeAm4MLKHdz9UMXdeYB+C2XWGMqHQdGVC8bGKNSikDTKJvjaK4EnKu5vB144eSczey9wBdABnFPticzsMuAygNWrV894oSLVDEctiq5sMDaPQqfHShrN+sFsd7/W3U8C/ifw9zX2We/ufe7et2zZsvYWKEes8vWyuzoCjVFIqiUZFDuA4yrur4q21XIT8LpYKxJ5GoZHw1NhuyrmUej0WEmjJINiE7DWzE4wsw7gYmBD5Q5mtrbi7muAh9tYn0hdY11PufEW
hXJC0iixMQp3L5jZ5cBtQABc7+73mdk1wGZ33wBcbmbnAqPAU8ClSdUrMtnegREgDIryGk9qUUgaJTmYjbtvBDZO2nZ1xe33tb0okQa960t3AdCRzTBS0BiFpNesH8wWme2GR4sVYxQKCkkfBYVIk555dC8AzzpmPtno9Fi1KCSNEu16EpnLXnD8Uew/PMqing61KCTV1KIQadJoMbxoEVBx1pOCQtJHQSHSpNFiiVy0dIdaFJJmCgqRJoVBEbUoxtZ60umxkj4KCpEm5Qs+HhRRi2JUq8dKCikoRJqUL5bIRWMU5cAYLapFIemjoBBp0mihREfU5dSdC4Dxa1SIpImCQqRJlWMUneWgyCsoJH0UFCJNGi2W6Ii6nsotimG1KCSFFBQiTcoXxwezc4FhBiMFjVFI+igoRJqULxTpiILCzMiYja0iK5ImCgqRJh0cKrCge3wVnMAMnfQkaaSgEGlCqeQ8NZjnqHkdY9syGdSikFRSUIg0oX+4QLHkLO6pCAozrfUkqaSgEGnC/sE8wIQWRWBGUS0KSSEFhUgTyvMlejqCsW2ZjFoUkk4KCpEm5KNR6/I8CoCMgXJC0khBIdKE8ppO5XkUEC41rq4nSSMFhUgT8tHEuo6gskWhridJJwWFSBPKXU+5iq6n3f0jbHniQFIlicRGQSHShGotCoAHd/UnUY5IrBQUIk0oB0VnVr9Ckn6JvsvN7Dwze8jMtpnZlVUev8LM7jez35jZ7WZ2fBJ1ikxWbTBbJK0Se5ebWQBcC5wPrAMuMbN1k3a7G+hz91OBW4GPtrdKkerGup6qtChcZz5JyiT5cehMYJu7P+LueeAm4MLKHdz9h+4+GN39FbCqzTWKVFWvRVHUmU+SMkkGxUrgiYr726Nttbwd+Ha1B8zsMjPbbGab9+zZM4MlilQ3UqVFcdHp4ecYzaWQtJkTHaxm9magD/hYtcfdfb2797l737Jly9pbnByRyqfHVg5mP2N5L6AWhaRPdvpdYrMDOK7i/qpo2wRmdi5wFfAydx9pU20idY2MTu16ymYMUFBI+iTZotgErDWzE8ysA7gY2FC5g5k9H/gscIG7706gRpGqhkaLdOcCgigcIFwUEKCkixdJyiQWFO5eAC4HbgMeAL7i7veZ2TVmdkG028eAXuAWM9tiZhtqPJ1IWx0eKUxYORYgiDJDYxSSNkl2PeHuG4GNk7ZdXXH73LYXJdKA/uECPZ2TgkJdT5JSc2IwW2S2eWTvACcs7Z2wbazrSS0KSRkFhUgTRkZL9E5uUZhaFJJOCgqRJuSLpSkLAqrrSdJKQSHShNFCacqsbAWFpFXDg9lm9lzCNZm6ytvc/YY4ihKZ7fLF0oRrUUBFUGiMQlKmoaAws38AziYMio2EC/n9DFBQyBEpX5ja9ZSx8jwKBYWkS6NdT28AXgHscve3Ac8DFsZWlcgsN1r0KSvHqkUhadVoUAy5ewkomNkCYDcTl98QOaJUG8zO6KwnSalGxyg2m9ki4HPAncAA8MvYqhKZxYolp1jyKYPZWutJ0qqhoHD390Q3P2Nm3wEWuPtv4itLZPYauxZF1iZs11lPklZP56ynU4E15e8xs2e4+3/GVJfIrFVeYnxK15NmZktKNXrW0/XAqcB9QHltTAcUFHLEGa1xGdTxmdltL0kkVo22KM5y98nXsxY5ItVuUYRf1fUkadPoWU+/NDMFhQgwWgiDYMrMbFPXk6RToy2KGwjDYhcwAhjg7n5qbJWJzFL5YhGY2vWUjS5IUVCLQlKm0aD4PPAW4F7GxyhEjkj5Gi0KzcyWtGo0KPa4u64uJ8L46bEdOj1WjhCNBsXdZnYj8A3CricAdHqsHInGB7OrX+GuoItmS8o0GhTdhAHxqoptOj1Wjkjl02NzwcQWRWc2DI6RgoJC0qXRmdlvi7sQkblipFh9HkVndD+voJCUaXTC3QnAX1IxMxvA3S+IpyyR2Wu8RTExKMrBkdeMO0mZRruevkZ45tM30FlP
coQbLYaD1bVaFCOj+hWRdGk0KIbd/VOxViIyR4zNo1CLQo4QjQbFJ6Or3H2XiWc93RVLVSKz2NjM7MkT7qI1PHR6rKRNo0FxCuGEu3OYuCjgOa28uJmdB3wSCIDr3P3Dkx5/KfAJwgUJL3b3W1t5PZGZMFJjrafy9SgKRQWFpEujaz29ETjR3V/m7i+P/rUaEgFwLeH1t9cBl1RZT+r3wFuBG1t5LZGZNLZ6bI1lxv/P93/b9ppE4tRoUGwFFs3wa58JbHP3R9w9D9wEXFi5g7s/Fl0gSZ2+MmscHikA0N0R1NxHy3hImjTa9bQIeNDMNjFxjKKV02NXAk9U3N8OvLCZJzKzy4DLAFavXt1CSSLTOzg0SncumHLWU6XD+QLzu3JtrEokPo0GxT/EWkWL3H09sB6gr69PH+UkVvsP51nUUz8EDg0rKCQ9Gp2Z/eMYXnsHcFzF/VXRNpFZbfuBIVYt7q67T//wKOHKNyJzX0NjFGZ2lpltMrMBM8ubWdHMDrX42puAtWZ2gpl1ABcDWqFWZr2Dg6McNa+j7j6HhgptqkYkfo0OZv8rcAnwMOHHpHcQnrHUNHcvAJcDtwEPAF9x9/vM7BozuwDAzM4ws+2EZ1191szua+U1RWbCoeHRabuVhkeLbapGJH6NjlHg7tvMLHD3IvBvZnY38IFWXtzdNwIbJ227uuL2JsIuKZFZo3+4wIJpgkKXQ5U0aTQoBqPuoXvM7KPAThpvjYikRqFYYmCkwILu+r86yglJk0b/2L8l2ve9wGHCT/kXxVWUyGw1EM2hqNX1dNlLTwTUopB0qRsUZnahmb3X3R9392Hge4Qzpf8UOK0N9YnMKvsO5wFYUmMw+4LnHQuA5ttJmkzXong/E89E6gReAJwNvDummkRmrT394XzTpb2dVR+36KJ3alFImkw3RtHh7pWzp3/m7vuB/WY2L8a6RGalJw8NA3DMwq6qj2eipHAFhaTIdC2KxZV33P3yirvLZr4ckdlt58HGgkJdT5Im0wXFHWb2F5M3mtk7gV/HU5LI7LXr4DDzO7P0dlZvjGfU9SQpNF3X098AXzOzPwPKFyl6AeFYxeviLExkNtp1cLhmawLA1KKQFKobFO6+G3ixmZ0DPCfa/C13/0HslYnMQjsP1Q+KcotCYxSSJo0uCvgDQOEgR7ydB4Z45jNrD8+Nj1EoKCQ9NLtapEGD+QK7+0dYs7T2CX9jQaFLbUmKKChEGvTY3kEAjl/SU3Of8jyKoloUkiIKCpEGPb7vMABrltRpUWQ0j0LSR0Eh0qBHy0FRp+sp0FlPkkIKCpEG7RvI09MR1JxDAZpHIemkoBBp0GC+SE9HUHcfzaOQNFJQiDRoeLRI9zRBoXkUkkYKCpEG9Q8X6MnVn3o0fnqsgkLSQ0Eh0gB3Z8sTB3jWivl199OigJJGCgqRBvxuzwB7B0Z48UlL6u5n0W+UBrMlTRQUIg14cFc/AKeuWlR3v/HrUcRekkjbKChEGjA8Gq7JUe/UWNDpsZJOCgqRBowUigB0Zuv/yrQ6RjHd2VKD+cKE+8WSMzBSqLG3yMxoaPVYkSPdSNSi6MxON48i/NpMi+Kpw3le/s8/4sDgKC8/eRmffvML6MoFuDsf/95v+ZcfbBvb90tvfyG7+4e54iv3AHDacYs47biwW+w32w9w/JJ5bNs9QHdHQLHkXH7OM1g6r5M1S3vIZjKYhf+nhT25p12nHHkSDQozOw/4JBAA17n7hyc93gncQHixpH3Am9z9sXbXKTJSiIIi12CL4mk2Ke76/VP844b7ODA4CsAPH9rDV+/ewSVnrubOx5+aEBIAb/78HRPub3niAFueOFDxfAcmPP62f9tU9XWzGWP1kh6Wz+8kF2RYfVQPI4USC7py7D88wu7+cAB/3bEL6MoFLOjK8ZxjF4xNLJQjQ2JBYWYBcC3wSmA7sMnMNrj7/RW7vR14yt2fYWYXAx8B3tT+auVI1j88yte37GBpbycdQf2gaHat
p9f/31+M3f7p+1/ORZ/+BXc9/hQXPO9YPviN+wB481mreedLT+Lqr2/lhw/tAeAXV55Ddy4gGxj9wwX2H87z+/2DnLJyIUOjRVYt7mbjvbu4edPvOe6oHno7s3R3BGzdcZCOIMPu/hEe3zc4Nvj+04f3TqntF7/bN+H+/K4sS3s7Wbu8l+FCiYzBnY8/RW9nllefsoJckGEwX+D4JfNYuaiLtUfPZ35nllyQ4Z7tB1gyr5NFPTlyQYalvR1kpzmmkrwkWxRnAtvc/REAM7sJuBCoDIoLgQ9Gt28F/tXMzGOY9losObv7hydsq/YqkzdVK6WR6ibv41OeuYXXb6gen3afat/WTN3Vn7t9x636/2Pq1sF8kd39w+w+NEJHNsPa5fO57w8HueGXj7Pr0DDXv/WMsdVha2ml66msuyNg+YJObrlzO7fcuR2AF524hGsueC6ZjPGh15/Cm6+7g79/zTqOXdQ99n3zu3Icu6ib565cOOH53vCCVbzhBasaeu2DQ6PsHRhhzZJ5DI8WcWBv/wh7B0YYGi2y/akh7v/DIbbtHmDb7gE6cwG5wOjpCBgYLvD5nz1KLrAoLIrTvl4uMDqzARmDBd05OrIZAjPyxRJBxli5qJsFXeF29/Cn7R62hJYt6CSbMQIzzIzOXIZ5HdnwhAIzMha28DIGQSZDtBkqv2Is6M6OPZ4xw4xo3/HbmYyNfX+QyZDNlB8zMpnoq41/T/l1jWg/q3zu+FpjucBY0ts548+bZFCsBJ6ouL8deGGtfdy9YGYHgSXA1I89LTowmOdF/1sX8ZOp1i7v5ZZ3vYjTVy+edt/yH4pGPsscGMyTDTJTzqTqzgW8ZO0ytu44BIR/FG/8ixeOdfesWNjN7X979tP/jzRgYXeOhd3huMW8qK7ezmzdFXMrFYolskGGUsnZOzDCQ0/28+ShEQ4M5gkyxuKeDgD6RwoUiyV2HBii5OH3HRoukC+WKJWcXJBhd/8w/cMFtj81RKFUGv+jC+QLJfYezlMsOSV3nY4cOe24RXztvX8048+bisFsM7sMuAxg9erVTT3HvM4sH379KVWeu8rrTf5EUHWfKTU2sE+V52ng9RvpLm7k9as9V7VPP1P3mf55qu1V/f82eZ+ZOW5VP8VN2tSZzbB8fhcZg0/d/jDHLurm1FULeeW6YwimaUlUypg11PV02jXf48Rl87j9ipdN2N6dC/i7V53MknkdzO/K8rrnr5wzYwLlbqRMxli+oIvlC2pfX3wmuTtDo0WG8mErqOQOHnYBFt0pFn2s9VkOlfJ+B4dGw9aKR9tK4y2XylZMKbpdLJUoFMvbPXps/PHxbU7JqXhuj33G/lHzOmJ53iSDYgdwXMX9VdG2avtsN7MssJBwUHsCd18PrAfo6+tr6kfRlQu4+MzmQkbS5xMXP7/p783Y9F1P5dNcH9lzmHxx/Lqp5z57+Vj31jtecmLTNRxpzIyejiw9Han47DvrJDmKtAlYa2YnmFkHcDGwYdI+G4BLo9tvAH4Qx/iEyEyyBloUv31yYOz2aDHc+apXP5vrLj0jztJEmpJY/EZjDpcDtxGeHnu9u99nZtcAm919A/B54N/NbBuwnzBMRGa1TANjFL9+NGwYr1nSw84DQ8D4pD6R2SbRdpq7bwQ2Ttp2dcXtYeCN7a5LpBXhGEX9oPjDgfAMu6W9nfzHr8NzOlbXuRa3SJJ0ArPIDGtkMLt8Kva+w3mu//mjALz21BVxlybSFAWFyAyzaQaz84US39m6C4BH9x4GIMjYnDmzSY48CgqRGZYxq3te/6N7D09pcfz4f5wda00irVBQiMyw6U6P/UM0eH3swvE5BisrZliLzDY66VhkhmXMKNYYpPjQxgfGup0W9XTwh4PDvPflJ6nbSWY1tShEZliQqX7W02ixxPqfPMLv9w8CsHJx2Io451nL21qfyNOlFoXIDAsyRqE4NSiGR8fnSSzqyXHl+c/ijDWLG1pDSiRJCgqRGRZkjGKVFkX5
mhYAy+d3ctKyXk5a1tvO0kSaoq4nkRkWZKqPUVS2KJbPb89ieSIzQUEhMsOCGoPZE1oUC2b+mgEicVFQiMywWi2K8nW3AdatWNDOkkRaoqAQmWE1g6Ji0b+j23SdBpGZoMFskRn24K5+HtzVP2X7cEWLQkEhc4laFCJtUtmieNaK+QlWIvL0KChEYjJaceU6gINDowB866/+mAVduSRKEmmKgkIkJgPDhQn37995CIDFPfFc11gkLgoKkZhMblEY4XpOx2oBQJljFBQiM+wjF50CQL5YIl8ocefjT4X3CyV6O3X+iMw9CgqRGdaRDX+t9g7k+adv3s9Fn/4F23YPkC8WyQVaJVbmHn28EZlhj+4Jr1p3+Y13sWx+OAP74FCex/cNMs0VUkVmJbUoRGZadG2J7U8NVVzJzvjpw3s5MDiaWFkizVJQiMywCYPY0SqyhUkD2yJziYJCZIa97Y/WAPDa5x071qIYqlg5VmSuUVCIzLDl87tYMq+DBV3ZsSvdlZcYf/fZJyVZmkhTFBQiMchEl0MtX7+ovM6TVo2VuSiRoDCzo8zse2b2cPS16rUgzew7ZnbAzL7Z7hpFWpGNVpCd3KLo6QiSLEukKUm1KK4Ebnf3tcDt0f1qPga8pW1VicyQjBmF0niLYjBfDgqdkS5zT1JBcSHwxej2F4HXVdvJ3W8Hpq7XLDLLZQOjVHI8mjlxeCRc90ktCpmLkgqKo919Z3R7F3B0K09mZpeZ2WYz27xnz57WqxNpURC1KMpnPQ0oKGQOi60dbGbfB46p8tBVlXfc3c2spQmr7r4eWA/Q19enya+SuCAazC6PUYwFhdZ6kjkotnetu59b6zEze9LMVrj7TjNbAeyOqw6RJAQZo1B0ymt2fPmO3wPQk1OLQuaepLqeNgCXRrcvBb6eUB0isZjcoiib36UWhcw9SQXFh4FXmtnDwLnRfcysz8yuK+9kZj8FbgFeYWbbzexPEqlW5GkKMhPHKADOf+4xZANNXZK5J5GPN+6+D3hFle2bgXdU3H9JO+sSmSkZC+dReMV6sd/euivBikSap483IjEYm3BXsRbg+887ObmCRFqgDlORGGQyxqbH9jNaHG9RvPCEoxKsSKR5alGIxCCbsQkhAbC4pyOhakRao6AQiUGQmXrJ06MXdCVQiUjrFBQiMfjpw3sn3H/zWauZp8l2MkcpKETa4APnPzvpEkSapqAQiVmQMbUmZE5TUIjELFtlvEJkLlFQiMQsp9nYMsfpHSwSs1ygFoXMbQoKkRhUhoNaFDLX6R0sEoNb3/XisdsKCpnr9A4WiUFXxXUn1PUkc52CQiQGWXU9SYroHSwSg4XdubHbCgqZ6/QOFonBosqgyOrXTOY2vYNFYpANMrzn7JMAyGnCncxxCgqRmPStWQxAr66TLXOcgkIkJsOj4eXtlszrTLgSkdboo45ITF657mje+bITec/LnpF0KSItUVCIxCQXZLS8uKSCup5ERKQuBYWIiNSloBARkboSCQozO8rMvmdmD0dfF1fZ5zQz+6WZ3WdmvzGzNyVRq4jIkS6pFsWVwO3uvha4Pbo/2SDw5+7+HOA84BNmtqiNNYqICMkFxYXAF6PbXwReN3kHd/+tuz8c3f4DsBtY1rYKRUQESC4ojnb3ndHtXcDR9XY2szOBDuB3NR6/zMw2m9nmPXv2zGylIiJHuNjmUZjZ94Fjqjx0VeUdd3cz8zrPswL4d+BSdy9V28fd1wPrAfr6+mo+l4iIPH3m3v6/q2b2EHC2u++MguBH7n5ylf0WAD8CPuTutzb43HuAx2ey3hm2FNibdBF1qL7WqL7WqL7WtFLf8e5etXs/qZnZG4BLgQ9HX78+eQcz6wC+CtzQaEgA1PqPzhZmttnd+5KuoxbV1xrV1xrV15q46ktqjOLDwCvN7GHg3Og+ZtZnZtdF+/xX4KXAW81sS/TvtGTKFRE5ciXSonD3fcArqmzfDLwjuv0l4EttLk1ERCbRzOz2W590
AdNQfa1Rfa1Rfa2Jpb5EBrNFRGTuUItCRETqUlCIiEhdCoqYmNlxZvZDM7s/WtjwfdH2D5rZjoozuV6dYI2Pmdm9UR2bo23TLtjYptpOrjhGW8zskJn9ddLHz8yuN7PdZra1YlvVY2ahT5nZtmhhy9MTqu9jZvZgVMNXy2ummdkaMxuqOJafSai+mj9TM/tAdPweMrM/Sai+mytqe8zMtkTbkzh+tf6uxPsedHf9i+EfsAI4Pbo9H/gtsA74IPB3SdcX1fUYsHTSto8CV0a3rwQ+MgvqDAiXejk+6eNHeMr26cDW6Y4Z8Grg24ABZwF3JFTfq4BsdPsjFfWtqdwvweNX9Wca/b7cA3QCJxAu4RO0u75Jj/8zcHWCx6/W35VY34NqUcTE3Xe6+13R7X7gAWBlslU1ZNoFGxPwCuB37p74jHt3/wmwf9LmWsfsQsIJo+7uvwIWRSsRtLU+d/+uuxeiu78CVsVZQz01jl8tFwI3ufuIuz8KbAPOjK046tdnZkY4v+s/4qyhnjp/V2J9Dyoo2sDM1gDPB+6INl0eNQOvT6prJ+LAd83sTjO7LNr2tBZsbJOLmfjLOVuOX1mtY7YSeKJiv+0k/2HhvxN+wiw7wczuNrMfm9lLkiqK6j/T2Xb8XgI86dGq1pHEjt+kvyuxvgcVFDEzs17g/wF/7e6HgE8DJwGnATsJm7JJ+WN3Px04H3ivmb208kEP266Jnj9t4VIuFwC3RJtm0/GbYjYcs1rM7CqgAHw52rQTWO3uzweuAG60cH21dpvVP9MKlzDxA0tix6/K35UxcbwHFRQxMrMc4Q/zy+7+nwDu/qS7Fz1cCfdzxNyUrsfdd0RfdxOuq3Um8GS5aRp93Z1UfZHzgbvc/UmYXcevQq1jtgM4rmK/VdG2tjOztwL/Bfhv0R8Soi6dfdHtOwnHAJ7Z7trq/Exn0/HLAq8Hbi5vS+r4Vfu7QszvQQVFTKL+zM8DD7j7xyu2V/YP/imwdfL3toOZzTOz+eXbhAOeWxlfsBFqLNjYZhM+xc2W4zdJrWO2Afjz6MyTs4CDFd0DbWNm5wHvBy5w98GK7cvMLIhunwisBR5JoL5aP9MNwMVm1mlmJ0T1/brd9UXOBR509+3lDUkcv1p/V/25dxgAAAJ3SURBVIj7PdjOEfsj6R/wx4TNv98AW6J/rya8tsa90fYNwIqE6juR8IySe4D7gKui7UsIL0/7MPB94KgEj+E8YB+wsGJbosePMLR2AqOE/b1vr3XMCM80uZbwk+a9QF9C9W0j7Kcuvw8/E+17UfSz3wLcBbw2ofpq/kwJr1/zO+Ah4Pwk6ou2fwF416R9kzh+tf6uxPoe1BIeIiJSl7qeRESkLgWFiIjUpaAQEZG6FBQiIlKXgkJEROpSUIjEyMyuMbNzk65DpBU6PVYkJmYWuHsx6TpEWqUWhUgTomsRPGhmXzazB8zsVjPria5X8BEzuwt4o5l9wczeEH3PGWb2CzO7x8x+bWbzzSyw8HoRm6JF8d4Z7bvCzH4SXedga8IL9skRLpt0ASJz2MmEM3d/bmbXA++Jtu/zcLHF8vIZ5cUNbwbe5O6bosXjhghnJh909zPMrBP4uZl9l3Bdodvc/X9Fy0T0tPe/JjJOQSHSvCfc/efR7S8BfxXdvrnKvicDO919E4BHK36a2auAU8utDmAh4ZpBm4DrowXgvubuW2L6P4hMS0Eh0rzJA3zl+4efxnMY8JfuftuUB8Jl318DfMHMPu7uNzRXpkhrNEYh0rzVZvai6PafAT+rs+9DwAozOwMgGp/IArcB745aDpjZM6OVfY8nvEjO54DrCC/PKZIIBYVI8x4ivODTA8BiwgvwVOXueeBNwL+Y2T3A94AuwhC4H7jLzLYCnyVs6Z8N3GNmd0ff98kY/x8iden0WJEmRJeh/Ka7PzfhUkRipxaFiIjUpRaFiIjUpRaFiIjUpaAQEZG6FBQiIlKXgkJEROpSUIiISF3/HzeFUqWgBnYU
AAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import pylab\n",
    "import numpy as np\n",
    "def compute_gamma(S):\n",
    "    inputs = torch.tensor([[110.0, 100.0, S, 0.35, 0.1, 0.05]]).cuda()\n",
    "    inputs.requires_grad = True\n",
    "    x = model(inputs)\n",
    "    loss_grads = grad(x, inputs, create_graph=True)\n",
    "    drv = grad(loss_grads[0][0][2], inputs)\n",
    "    return drv[0][0][2]\n",
    "\n",
    "prices = np.arange(10, 200, 0.1)\n",
    "deltas = []\n",
    "for p in prices:\n",
    "    deltas.append(compute_gamma(p).item())\n",
    "fig2 = pylab.plot(prices, deltas)\n",
    "pylab.xlabel('prices')\n",
    "pylab.ylabel('Gamma')\n",
    "fig2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[Implied volatility](https://en.wikipedia.org/wiki/Implied_volatility) is the forecasted volatility of the underlying asset based on the quoted prices of the option. It is the inverse mapping from the option price back to a model parameter, which is hard to do with the Monte Carlo simulation approach. With the deep learning pricing model, however, it is an easy task. We can first plot the relationship between volatility and the option price:"
   ]
  },
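  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recovering the implied volatility is then a one-dimensional root-finding problem on a price curve that is monotone in volatility. A minimal bisection sketch, with a stand-in monotone `price` function where the trained model would be called:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def price(sigma):\n",
    "    # stand-in monotone pricing function; call the trained model here instead\n",
    "    return 10.0 * sigma + 1.0\n",
    "\n",
    "def implied_vol(target, lo=0.0, hi=1.0, tol=1e-6):\n",
    "    # bisection, assuming price() is increasing in sigma on [lo, hi]\n",
    "    while hi - lo > tol:\n",
    "        mid = 0.5 * (lo + hi)\n",
    "        if price(mid) < target:\n",
    "            lo = mid\n",
    "        else:\n",
    "            hi = mid\n",
    "    return 0.5 * (lo + hi)\n",
    "\n",
    "print(implied_vol(3.5))  # about 0.25, since price(0.25) = 3.5"
   ]
  },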
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<matplotlib.lines.Line2D at 0x7fa5e841c2b0>]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAEGCAYAAABiq/5QAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3dd3xV9f3H8dcnhD3C3oSACDIFkjAcpY7WUaq2uBgKEkBbq/3po1at1mpbfx3WWq22yogBkeFs8VdttYpFQEISVlgykkDYO4wQyPj+/rgXGyKRBHLvueP9fDzy4Nxzz8l558J9c/K933uuOecQEZHoEeN1ABERCS4Vv4hIlFHxi4hEGRW/iEiUUfGLiESZWK8DVEXLli1dQkKC1zFERMJKVlbWPudcq4rrw6L4ExISyMzM9DqGiEhYMbMtZ1qvoR4RkSij4hcRiTIqfhGRKKPiFxGJMip+EZEoo+IXEYkyKn4RkSij4hcRCUHHT5by5Lw1HCo8WePfW8UvIhJiysocD76xgumf57E8/1CNf38Vv4hIiHnmwy/4YPUuHru+J1f0aF3j31/FLyISQt7IyOevn25m9OB4Ui7rEpBjqPhFRELEok37+Nm72Vx+YUuevKE3ZhaQ46j4RURCwKY9R7hnZhZdWzXkpdEDqV0rcPWs4hcR8dj+oye4Ky2DurExpI5Lpkm92gE9XlhclllEJFIVFZcy6bUs9hw+wZxJQ+jYrEHAj6niFxHxiHOOh95aRdaWg/xl9EAGxDcLynE11CMi4pHnPtrAeyt38PC1F3F933ZBO66KX0TEA29nbeOFTzZxW1In7hnWNajHVvGLiARZes5+HnlnFZdc0IJff69PwKZtVkbFLyISRLn7jnH3zCzimzfgr6MTAzptszIqfhGRIDl47CR3vbqUGDNeHTeIuAaBnbZZGc3qEREJghMlpdz9WhY7CoqYPXEw8S0CP22zMjrjFxEJMOccj76dzdK8Azxzcz8SOzf3NI+KX0QkwP78ySbeWb6dB7/VnRv7d/A6jopfRCSQ/r5iO3/8aAPfH9CB+67s5nUcQMUvIhIwWVsO8NBbqxjUpTm/GdE36NM2K6PiFxEJgK37C5k4I4sOTevzyphE6sbW8jrSl1T8IiI1rKCwmLvSllLmHKnjkmnWsI7XkU6j4hcRqUEnS8r4wetZbD1QyCtjEunSsqHXkb5C8/hFRGqIc46f/201izfv59lbLmZw1xZeRzojnfGLiNSQl/+Tw9zMfO6/shsjEjt6HadSKn4RkRrwfvZOfvfP9Xz34vY88K3uXsf5Wip+EZHztCL/EA/MXUFi52Y8c3O/kJm2WRkVv4jIedh2sJAJ0zNp3aQuk+9IpF7t0Jm2WRm9uCsico4OFxUzPi2DEyWlzJk0mBaN6nodqUp0xi8icg5KSsu49/Vl5Ow9xitjEunWurHXkapMZ/wiItXknOMX89bw2cZ9/G5EXy7p1tLrSNUSsDN+M0s1sz1mtrrcuovN7HMzyzaz98ysSaCOLyISKNMW5vJ6+lbuGXYBtyXHex2n2gI51JMGXFth3VTgEedcX+Bd4KEAHl9EpMZ9uGYXT7+/juv7tuWn1/TwOs45CVjxO+cWAAcqrO4OLPAvfwSMCNTxRURqWva2An48ZwX9Ojblj7f2JyYmtKdtVibYL+6uAW70L98CdKpsQzObZGaZZpa5d+/eoIQTEanMjkPHSZmeQfOGdZhyZ3hM26xMsIt/PPBDM8sCGgMnK9vQOTfZOZfknEtq1apV0AKKiFR09EQJKdMzOX6ylNRxybRuXM/rSOclqLN6nHPrgW8DmFl34DvBPL6ISHWVlJZx/+zlbNh9hNRxyfRoGz7TNisT1DN+M2vt/zMGeBx4OZjHFxGprl//Yx2frN/DUzf0Zlj3yBh9COR0ztnA50APM9tmZinASDPbAKwHdgCvBur4IiLna/riPNIW
5zHhsi6MGdLZ6zg1JmBDPc65kZXc9XygjikiUlM+Wb+bp95bw9U92/Do9T29jlOjdMkGEZEK1u44zH2zltOrfRNeGNmfWmE6bbMyKn4RkXJ2Hy4iZXoGjevVZtrYZBrUibwr20TeTyQico4KT5aQMj2DguPFvHnPUNo0Ce9pm5XRGb+ICFBa5vjxnBWs3XGYF0cNoHf7OK8jBYyKX0QE+O0H6/ho7W6eGN6LKy9q43WcgFLxi0jUez19C1M+y2Xs0M6Mu7SL13ECTsUvIlFtwYa9PPH3NVzRoxU/H97L6zhBoeIXkaj1xa4j3Pv6Mi5s3Yg/jxpIbK3oqMTo+ClFRCrYe+QE49MyqF+nFqnjkmlUN3omOUbPTyoi4ldUXMqEGZkcOHaSN+4eSvum9b2OFFQqfhGJKmVljgffWMGqbYd4eUwifTtG7rTNymioR0SiyjMffsH72bv42XU9uaZ3W6/jeELFLyJR442MfP766WZGDY5nwuWRP22zMip+EYkKizft42fvZnP5hS156obemEXWhdeqQ8UvIhFv056j3DMzi66tGvLS6IHUjpJpm5WJ7p9eRCLe/qO+aZt1YmOYNjaZJvVqex3Jc5rVIyIRq6i4lEmvZbH7cBFzJg2hU/MGXkcKCSp+EYlIzjl++tYqsrYc5C+jBzIgvpnXkUKGhnpEJCI99++NzFu5g59e24Pr+7bzOk5IUfGLSMR5Z9k2Xvh4I7cmdeQHwy7wOk7IUfGLSERJz9nPw2+vYmjXFvz6pr5RPW2zMip+EYkYufuOcffMLDo1b8DLYxKpE6uKOxM9KiISEQ4eO8n4tAxizHh1XDJxDTRtszIqfhEJeydLyrh7ZhbbDx5n8h2JdG7R0OtIIU3TOUUkrDnneOSdVSzNPcDzt/cnKaG515FCns74RSSsvfjJJt5Ztp0Hru7Ojf07eB0nLKj4RSRszVu5g2c/2sD3B3Tg/qu6eR0nbKj4RSQsZW05wE/eXMmghOb8ZoSmbVaHil9Ews7W/YVMnJFF+7h6vHJHInVja3kdKayo+EUkrBQcL+autKWUOUfquGSaNazjdaSwo+IXkbBRXFrGD1/PYuuBQl4ek0jXVo28jhSWNJ1TRMKCc47H313Nok37efaWixnStYXXkcKWzvhFJCy8siCHuZn53HdlN0YkdvQ6TlhT8YtIyPsgeye//WA9w/u144Gru3sdJ+yp+EUkpK3IP8T/zF3BwPim/OGWi4mJ0bTN86XiF5GQte1gIROmZ9K6SV2m3JlEvdqatlkT9OKuiISkw0XFpKRlcqKklDmTBtOiUV2vI0UMnfGLSMgpKS3jR7OWs3nvUV4ek0i31o29jhRRAlb8ZpZqZnvMbHW5df3NbImZrTCzTDMbFKjji0h4cs7x5HtrWLBhL7++qQ+XdmvpdaSIE8gz/jTg2grrfg885ZzrDzzhvy0i8qVpC3OZuWQrdw/ryu2D4r2OE5ECVvzOuQXAgYqrgSb+5ThgR6COLyLh56O1u3n6/XVc16ctD19zkddxIlawX9z9H+BfZvYHfP/pXFLZhmY2CZgEEB+v//VFIt3q7QXcP3s5/TrE8cdb+2vaZgAF+8XdHwAPOOc6AQ8A0yrb0Dk32TmX5JxLatWqVdACikjw7Sw4Tsr0DJo3rMOUsUnUr6Npm4EU7OIfC7zjX34T0Iu7IlHu2IkSxqdlcuxEKanjkmnduJ7XkSJesIt/BzDMv3wlsDHIxxeREFJa5rh/9nI27D7CS6MH0qOtpm0GQ8DG+M1sNvBNoKWZbQN+AUwEnjezWKAI/xi+iESnX/3fWj5ev4df3dSHYd01pBssASt+59zISu5KDNQxRSR8TF+cR9riPFIu68IdQzp7HSeq6J27IhJ089fv4an31nB1zzb87PqeXseJOip+EQmqdTsP86NZy+jZrgnP396fWpq2GXQqfhEJmj2Hi0hJy6BxvdpMG5tMw7q6TqQX9KiLSFAUniwhZXomh44X8+Y9Q2kb
p2mbXqnWGb+ZNQhUEBGJXEdPlDAuNYM1Owr488gB9G4f53WkqFal4jezS8xsLbDef/tiM/tLQJOJSEQ4XFTMndPSydp6kBdGDuCqnm28jhT1qnrG/xxwDbAfwDm3EvhGoEKJSGQoKCzmjqnpZG8v4KVRAxner73XkYRqjPE75/LNTnv1vbTm44hIpDh47CRjpqWzcbfvw1R0ph86qlr8+WZ2CeDMrDbwY2Bd4GKJSDjbd/QEY6amk7vvGFPGJulduSGmqkM99wD3Ah2A7UB//20RkdPsOVLEyMlLyNt/jNRxySr9EFSlM37n3D5gdICziEiY21VQxKgpS9h1uIi0uwYxpGsLryPJGVR1Vs90M2ta7nYzM0sNXCwRCTfbDx3ntsmfs+fICWaMV+mHsqqO8fdzzh06dcM5d9DMBgQok4iEmfwDhYycsoSC48W8ljKIAfHNvI4kX6OqY/wxZvbl36SZNUfv+hURIG/fMW575XOOFJUwa8IQlX4YqGp5Pwt8bmZvAgbcDDwdsFQiEhY27z3KqClLKC51zJ44hF7tm3gdSaqgqi/uzjCzTHyfmgXwfefc2sDFEpFQt3H3EUZOSQd8pa9PzwofX1v8ZtbEOXfYP7SzC5hV7r7mzrkDgQ4oIqFn3c7DjJmaTq0YY9bEoXRr3cjrSFINZzvjnwUMB7IAV269+W93DVAuEQlRq7cXMGZaOvVr12LWxCF0adnQ60hSTV9b/M654ea7TsMw59zWIGUSkRC1Mv8Qd0xLp3G92syeOIT4Frpgbzg666we55wD/hGELCISwrK2HGTM1HSaNqjD3LtV+uGsqtM5l5lZckCTiEjIWpp7gDunpdOycV3m3j2Ejs1U+uGsqtM5BwNjzCwPOIZ/jN851y9QwUQkNCzevI+UtEzaN63H7IlDaN1En5wV7qpa/NcENIWIhKQFG/YycUYmCS0aMnPCYFo1rut1JKkBZ5vOWQ/flTm7AdnANOdcSTCCiYi35q/fw90zs+jWqhEzJwymecM6XkeSGnK2Mf7pQBK+0r8O3zt4RSTCfbhmF5Ney6RHm8bMmqjSjzRnG+rp5ZzrC2Bm04ClgY8kIl56P3sn989eTp8OcUwfP4i4+rW9jiQ17Gxn/MWnFjTEIxL5/r5iO/fNXk7/Tk15LUWlH6nOdsZ/sZkd9i8bUN9/+9SsHl2RSSRCvJ21jYfeWklyQnNSxyXTsK4uwBupzvbO3VrBCiIi3pmbsZVH3snm0gtaMuXOJOrX0VM/klX1DVwiEqFmLtnCw29n840LWzF1rEo/Guh3OZEo9uqiXJ56by1X92zNS6MHUjdWpR8NVPwiUWrKghyefn8d1/ZuywsjB1AnVgMA0ULFLxKFXpq/iWf+9QXD+7Xjudv6U7uWSj+aqPhFoohzjuc/3sif/r2R7w3owDM39yNWpR91VPwiUcI5x7MfbuDF+Zu4JbEjvx3Rj1ox5nUs8YCKXyQKOOf4zQfrmbwgh5GD4nn6pj7EqPSjlopfJMI553jqvbWkLc5j7NDOPHlDb3wfrCfRKmDFb2ap+D6vd49zro9/3Vygh3+TpsAh51z/QGUQiXZlZY4n5q1m5pKtpFzWhce/01OlLwE9408DXgRmnFrhnLvt1LKZPQsUBPD4IlGtrMzx6DvZzM3M555hF/DwtT1U+gIEsPidcwvMLOFM9/k/wP1W4MpAHV8kmpWWOR56ayXvLNvO/Vd244FvdVfpy5e8GuO/HNjtnNtY2QZmNgmYBBAfHx+sXCJhr6S0jAffWMm8lTt48Fvduf+qC72OJCHGqwm8I4HZX7eBc26ycy7JOZfUqlWrIMUSCW/FpWXcP2c581bu4OFrL1LpyxkF/YzfzGKB7wOJwT62SCQ7UVLKj2Yt56O1u3n8Oz2ZcHlXryNJiPJiqOdqYL1zbpsHxxaJSEXFpfzw9WV8sn4Pv7yxN3cOTfA6koSwgA31mNls4HOgh5ltM7MU/123c5ZhHhGpuqLiUibOyGT+
F3v43+/1VenLWQVyVs/IStaPC9QxRaJN4ckSJkzP5POc/fx+RD9uSerkdSQJA3rnrkiYOnqihPGvZpC55QDP3dqfmwZ08DqShAkVv0gYOlxUzLjUpazcVsALIwcwvF97ryNJGFHxi4SZgsJi7kxNZ+3Ow7w0aiDX9mnrdSQJMyp+kTBy8NhJxkxLZ+Puo/x1dCJX92rjdSQJQyp+kTCx7+gJxkxNJ2ffMSbfmcg3e7T2OpKEKRW/SBjYc6SI0VPSyT9YSOrYZC67sKXXkSSMqfhFQtyugiJGTVnCrsNFpN01iCFdW3gdScKcil8khG0/dJxRU5aw/+hJZowfRFJCc68jSQRQ8YuEqPwDhYycsoSC48XMSBnEwPhmXkeSCKHiFwlBW/YfY+TkJRw7WcrrEwbTr2NTryNJBFHxi4SYzXuPMmrKEopLHbMmDqZ3+zivI0mEUfGLhJCNu48wcko64Jg9cQg92jb2OpJEIBW/SIhYt/MwY6amUyvGmDVxKN1aN/I6kkQoFb9ICFi9vYAx09KpX7sWsyYOoUvLhl5Hkgjm1UcviojfyvxDjJqyhIZ1Ypk7aahKXwJOZ/wiHsracpBxqUtp2rA2sycOoWOzBl5Hkiig4hfxyNLcA9z16lJaN6nHrImDaRdX3+tIEiU01CPigcWb9zE2dSlt4+oxZ9IQlb4ElYpfJMgWbNjLXa9m0Kl5feZMGkqbJvW8jiRRRkM9IkE0f/0e7p6ZxQWtGjEzZRAtGtX1OpJEIRW/SJB8uGYX985axkVtm/BayiCaNqjjdSSJUip+kSD4IHsn981eTp8OcUwfP4i4+rW9jiRRTGP8IgH29xXb+dHs5fTv1JTXUlT64j2d8YsE0NtZ23jorZUkJzQndVwyDevqKSfe079CkQCZm7GVR97J5tILWjLlziTq16nldSQRQEM9IgExc8kWHn47m29c2IqpY1X6Elp0xi9Sw15dlMtT763l6p6teWn0QOrGqvQltKj4RWrQlAU5PP3+Oq7p3YY/jxxInVj9Ui2hR8UvUkNemr+JZ/71Bd/p144/3daf2rVU+hKaVPwi58k5x/Mfb+RP/97ITf3b84dbLiZWpS8hTMUvch6cczz74QZenL+JmxM78rsR/agVY17HEvlaKn6Rc+Sc47cfrOeVBTmMHNSJp2/qS4xKX8KAil+kmopLy/hg9S6mfZbDym0F3Dm0M09+t7dKX8KGil+kigqOFzM3Yytpi/LYUVBE15YN+c33+3J7cifMVPoSPlT8ImeRf6CQ1EW5vJGRz7GTpQzt2oJf3dSHK3q01lm+hCUVv8gZOOdYtvUgUz/L5V9rdhFjxg0Xt2f8ZV3o0yHO63gi50XFL1JOSWkZ/1yzi6mf5bIi/xBx9Wtzz7ALuHNoAm3j9ElZEhlU/CLA4aJi3sjI59VFeWw/dJyEFg341Y29GZHYkQZ19DSRyBKwf9FmlgoMB/Y45/qUW38fcC9QCvzDOffTQGUQOZv8A4WkLc5jbkY+R0+UMLhLc568oTdXXaTxe4lcgTyVSQNeBGacWmFmVwA3Ahc7506YWesAHl+kUsu2HmTaZ7l8sHonMWYM79eOlMu60rejxu8l8gWs+J1zC8wsocLqHwC/dc6d8G+zJ1DHF6mopLSMD9fuZupnOSzbeogm9WKZ9I0LGHtJZ9rF1fc6nkjQBHvwsjtwuZk9DRQBP3HOZZxpQzObBEwCiI+PD15CiThHioqZm5FP2uI8th08TucWDfjljb0ZMbCjPhFLolKw/9XHAs2BIUAy8IaZdXXOuYobOucmA5MBkpKSvnK/yNlsO1hI2qI85vjH7wd1ac4Tw3txVc82up6ORLVgF/824B1/0S81szKgJbA3yDkkgi3fepCpC3P55+pdAP7x+y7069jU42QioSHYxf834Apgvpl1B+oA+4KcQSJQaZnjwzW7mLowl6wtB2lcL5YJl3dh7NAE2jfV+L1IeYGczjkb+CbQ0sy2Ab8AUoFUM1sN
nATGnmmYR6Sqjp4o8c2/X5xL/oHjxDdvwJPf7cUtSZ00fi9SiUDO6hlZyV1jAnVMiR7bDx1n+uI8Zqdv5ciJEpITmvHY9b34Vi+N34ucjU6JJKyszD/E1IW5vJ+9E4Dr+/rG7/t30vi9SFWp+CXklZY5Plq7m2kLc8jIO0jjurGkXNaFsZck0EHj9yLVpuKXkHXsRAlvZuaTuiiPrQcK6disPk8M78WtyZ1opPF7kXOmZ4+EnJ0Fx0lbnMes9K0cKSohsXMzHr3uIr7du63G70VqgIpfQsaqbYeYtjCXf6zaSZlzXOcfvx8Y38zraCIRRcUvniotc3y8bjdTF+ayNPcAjerGMu6SBMZekkCn5g28jicSkVT84onCkyW8lbWN1IW55O0vpEPT+jz+nZ7cltyJxvVqex1PJKKp+CWodhUUMf1z3/h9wfFiBsQ35aFrLuKa3m2IrRXjdTyRqKDil6BYvb2AqZ/l8H+nxu/7tGP8ZV1I7Kzxe5FgU/FLwJSVOT5ev4epn+WQ7h+/H3tJAuM0fi/iKRW/1LjCkyW8nbWN1EV55O479uX4/a3JnWii8XsRz6n4pcbsPlzE9MV5vO4fv7+4U1NeHDWAa3u31fi9SAhR8ct5W729gNSFuby3agelZY5rerdlwuW++fdmesOVSKhR8cs5KStzzP9iD1M/y+XznP00rFOLMUM6c9clXYhvofF7kVCm4pdqOX6ylLeX+ebf5+w7Rvu4evzs+ou4LTmeuPoavxcJByp+qZI9h4uY8fkWZqZv4VBhMf06xvHCyAFc16cttTV+LxJWVPzytdbuOMy0hbnMW7mdkjLHt3u1YcLlXUnqrPF7kXCl4pevKCtz/GfDXqYuzGHRpv00qFOL0YM7c9elCXRu0dDreCJynlT8QeCco8z5/nRAmXOc+qRh5/y3y22HA8fp+7gKy+X38d1XxX3825Xfp/z3W7fzMKkLc9m89xjt4urx6HUXcXtyPHENNH4vEikiuvhf+Hgj81buoMzXcKeV62nlV6GQTytUvlrcle3DmUo8DPXtEMfzt/fn+r7tNH4vEoEiuvhbN65LjzaNwcAAMyOm3LIZGL4/Y8ot+77Mvx3EfLl8+j4GxMT47uO07f67zJfHPH0fvvK9fcvw32PHWLljnjFP+cz+fWK+uu7UdjH+n40vl0/fB4OWDevSp0MTjd+LRLCILv7bB8Vz+6B4r2OIiIQU/R4vIhJlVPwiIlFGxS8iEmVU/CIiUUbFLyISZVT8IiJRRsUvIhJlVPwiIlHGnAv96wqY2V5gyznu3hLYV4NxaopyVY9yVY9yVU+o5oLzy9bZOdeq4sqwKP7zYWaZzrkkr3NUpFzVo1zVo1zVE6q5IDDZNNQjIhJlVPwiIlEmGop/stcBKqFc1aNc1aNc1ROquSAA2SJ+jF9ERE4XDWf8IiJSjopfRCTKhHXxm9m1ZvaFmW0ys0fOcH9dM5vrvz/dzBLK3feof/0XZnZNKOQyswQzO25mK/xfLwc51zfMbJmZlZjZzRXuG2tmG/1fY0MoV2m5x2tekHM9aGZrzWyVmX1sZp3L3efl4/V1ubx8vO4xs2z/sReaWa9y93n5fDxjLq+fj+W2G2FmzsySyq07v8fL99my4fcF1AI2A12BOsBKoFeFbX4IvOxfvh2Y61/u5d++LtDF/31qhUCuBGC1h49XAtAPmAHcXG59cyDH/2cz/3Izr3P57zvq4eN1BdDAv/yDcn+PXj9eZ8wVAo9Xk3LLNwD/9C97/XysLJenz0f/do2BBcASIKmmHq9wPuMfBGxyzuU4504Cc4AbK2xzIzDdv/wWcJX5Pkz2RmCOc+6Ecy4X2OT/fl7nCqSz5nLO5TnnVgFlFfa9BvjIOXfAOXcQ+Ai4NgRyBVJVcs13zhX6by4BOvqXvX68KssVSFXJdbjczYbAqZklnj4fvyZXIFWlJwB+BfwOKCq37rwfr3Au/g5Afrnb2/zrzriNc64EKABaVHFfL3IB
dDGz5Wb2HzO7vIYyVTVXIPYN9PeuZ2aZZrbEzG6qoUznkisF+OAc9w1WLvD48TKze81sM/B74P7q7OtBLvDw+WhmA4FOzrl/VHffs4noD1sPQzuBeOfcfjNLBP5mZr0rnJHI6To757abWVfgEzPLds5tDmYAMxsDJAHDgnncs6kkl6ePl3PuJeAlMxsFPA7U6Osf56qSXJ49H80sBvgjMC4Q3z+cz/i3A53K3e7oX3fGbcwsFogD9ldx36Dn8v/qth/AOZeFb+yuexBzBWLfgH5v59x2/585wKfAgGDmMrOrgceAG5xzJ6qzrwe5PH+8ypkDnPqNw/PH60y5PH4+Ngb6AJ+aWR4wBJjnf4H3/B+vQLxwEYwvfL+t5OB7cePUiyO9K2xzL6e/iPqGf7k3p784kkPNvZh0PrlancqB70Wf7UDzYOUqt20aX31xNxffC5XN/MuhkKsZUNe/3BLYyBleIAvg3+MAfGVwYYX1nj5eX5PL68frwnLL3wUy/ctePx8ryxUSz0f/9p/y3xd3z/vxOu8fwMsv4Hpgg/8f+WP+db/Ed5YDUA94E9+LH0uBruX2fcy/3xfAdaGQCxgBrAFWAMuA7wY5VzK+8cJj+H4zWlNu3/H+vJuAu0IhF3AJkO1/EmQDKUHO9W9gt//vawUwL0QerzPmCoHH6/ly/77nU67oPH4+njGX18/HCtt+ir/4a+Lx0iUbRESiTDiP8YuIyDlQ8YuIRBkVv4hIlFHxi4hEGRW/iEiUUfGLAGb2mJmt8V/RcoWZDTazqeWvICkSKTSdU6KemQ3F9/b4bzrnTphZS6COc26Hx9FEAkJn/CLQDtjn/Jc2cM7tc87tMLNPT10D3cxSzGyDmS01sylm9qJ/fZqZ/dV/0bMcM/ummaWa2TozSzt1AP82mf7fKp7y4ocUOUXFLwIfAp38xf4XMzvtYmtm1h74Ob7rpVwKXFRh/2bAUOABYB7wHL631fc1s/7+bR5zziXh+1yBYWbWL2A/jchZqPgl6jnnjgKJwCRgLzDXzMaV22QQ8B/nu75+Mb7LbZT3nvONmWYDu51z2c65Mnxv90/wb3OrmS0DluP7T0GvHYhndFlmEcA5V4rveiifmlk21btc8KmrX5aVWz51O+y3eOgAAADLSURBVNbMugA/AZKdcwf9Q0D1zju0yDnSGb9EPTPrYWYXllvVH9hS7nYGvuGZZv7LaI+o5iGa4LvAXIGZtQGuO6/AIudJZ/wi0Aj4s5k1BUrwXVFzEr6PxcT5Prjkf/FdSfUAsB7fp6ZViXNupZkt9++XDyyq2fgi1aPpnCJVYGaNnHNH/Wf87wKpzrl3vc4lci401CNSNU+a2QpgNb4PVvmbx3lEzpnO+EVEoozO+EVEooyKX0Qkyqj4RUSijIpfRCTKqPhFRKLM/wNJtVxCafPKIwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "def compute_price(sigma):\n",
    "    # price the option at a fixed parameter set, varying only sigma\n",
    "    inputs = torch.tensor([[110.0, 100.0, 120.0, sigma, 0.1, 0.05]]).cuda()\n",
    "    return model(inputs).item()\n",
    "\n",
    "sigmas = np.arange(0, 0.5, 0.1)\n",
    "prices = [compute_price(s) for s in sigmas]\n",
    "plt.plot(sigmas, prices)\n",
    "plt.xlabel('Sigma')\n",
    "plt.ylabel('Price')\n",
    "plt.show()"
   ]
  },
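  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the model is a differentiable function of its inputs, sensitivities such as vega (the derivative of price with respect to sigma) come directly from autograd rather than from finite differences. The sketch below uses a hypothetical `price_model` stand-in for the trained network so the cell runs on CPU:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# hypothetical stand-in for the trained model: any differentiable\n",
    "# function of the 6 input columns will do; column 3 is sigma\n",
    "def price_model(x):\n",
    "    return (x[:, 3] * 80.0 + 0.1 * x[:, 0]).sum()\n",
    "\n",
    "# create the leaf tensor with requires_grad=True so .grad is populated\n",
    "inputs = torch.tensor([[110.0, 100.0, 120.0, 0.2, 0.1, 0.05]],\n",
    "                      requires_grad=True)\n",
    "price = price_model(inputs)\n",
    "price.backward()\n",
    "vega = inputs.grad[0, 3].item()  # d(price)/d(sigma) via autograd\n",
    "print('vega', vega)"
   ]
  },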
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Given a quoted price `P`, the implied volatility is the value of `sigma` at which `compute_price(sigma)` equals `P`. Since the price increases monotonically with `sigma`, bisection is guaranteed to find this root."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "implied volatility 0.18517351150512695 error 4.76837158203125e-06\n"
     ]
    }
   ],
   "source": [
    "def bisection_root(small, large, fun, target, EPS=1e-6):\n",
    "    # fun must be monotonically increasing on [small, large] and the\n",
    "    # target value must lie between fun(small) and fun(large)\n",
    "    if fun(large) - target < 0:\n",
    "        raise ValueError('upper bound is too small')\n",
    "    if fun(small) - target > 0:\n",
    "        raise ValueError('lower bound is too large')\n",
    "    # halve the bracketing interval until it is narrower than EPS\n",
    "    while large - small > EPS:\n",
    "        mid = (large + small) / 2.0\n",
    "        if fun(mid) - target >= 0:\n",
    "            large = mid\n",
    "        else:\n",
    "            small = mid\n",
    "    mid = (large + small) / 2.0\n",
    "    return mid, abs(fun(mid) - target)\n",
    "quoted_price = 16.0\n",
    "sigma, err = bisection_root(0, 0.5, compute_price, quoted_price)\n",
    "print('implied volatility', sigma, 'error', err)"
   ]
  },
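  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an aside, bisection is robust but slow: each model evaluation gains roughly one bit of accuracy. A secant iteration typically converges in far fewer evaluations and likewise needs no derivative. The sketch below is not part of the original workflow; `price_fn` is a hypothetical stand-in for `compute_price` so the cell runs without the trained model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def secant_root(fun, target, x0, x1, eps=1e-8, max_iter=50):\n",
    "    # Secant method: a finite-difference slope replaces the analytic\n",
    "    # derivative, so any black-box pricing function works\n",
    "    f0, f1 = fun(x0) - target, fun(x1) - target\n",
    "    for _ in range(max_iter):\n",
    "        if abs(f1 - f0) < 1e-15:\n",
    "            break\n",
    "        x2 = x1 - f1 * (x1 - x0) / (f1 - f0)\n",
    "        x0, f0 = x1, f1\n",
    "        x1, f1 = x2, fun(x2) - target\n",
    "        if abs(f1) < eps:\n",
    "            break\n",
    "    return x1\n",
    "\n",
    "# hypothetical stand-in for compute_price: a toy monotone price curve\n",
    "price_fn = lambda sigma: 100.0 * sigma + 2.0\n",
    "print('implied volatility', secant_root(price_fn, 16.0, 0.0, 0.5))"
   ]
  },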
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
